Unverified Commit 9d44279e by Tianqi Chen Committed by GitHub

[RELAY][Refactor] TextPrinter, move ret_type after body in Function. (#1918)

parent 4c13ee22
...@@ -131,6 +131,13 @@ class BaseAttrsNode : public Node { ...@@ -131,6 +131,13 @@ class BaseAttrsNode : public Node {
*/ */
inline void PrintDocString(std::ostream &os) const; // NOLINT(*) inline void PrintDocString(std::ostream &os) const; // NOLINT(*)
/*! /*!
* \brief Visit attributes that do not equal the default value.
*
* \note This is useful to extract fields for concise printing.
* \param v The visitor
*/
TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
/*!
* \brief Get the field information * \brief Get the field information
* \return The fields in the Attrs. * \return The fields in the Attrs.
*/ */
...@@ -199,6 +206,7 @@ class DictAttrsNode : public BaseAttrsNode { ...@@ -199,6 +206,7 @@ class DictAttrsNode : public BaseAttrsNode {
TVM_DLL static Attrs make(Map<std::string, NodeRef> dict); TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
// implementations // implementations
void VisitAttrs(AttrVisitor* v) final; void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final; Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other) const final; bool ContentEqual(const Node* other) const final;
...@@ -300,15 +308,15 @@ struct AttrNopEntry { ...@@ -300,15 +308,15 @@ struct AttrNopEntry {
return *this; return *this;
} }
template<typename T> template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) { TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
return *this; return *this;
} }
template<typename T> template<typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
return *this; return *this;
} }
template<typename T> template<typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
return *this; return *this;
} }
}; };
...@@ -603,7 +611,7 @@ class AttrDocEntry { ...@@ -603,7 +611,7 @@ class AttrDocEntry {
return *this; return *this;
} }
template<typename T> template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) { TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
std::ostringstream os; std::ostringstream os;
os << info_->type_info << ", default=" << value; os << info_->type_info << ", default=" << value;
info_->type_info = os.str(); info_->type_info = os.str();
...@@ -649,6 +657,57 @@ class AttrExistVisitor { ...@@ -649,6 +657,57 @@ class AttrExistVisitor {
return AttrNopEntry(); return AttrNopEntry();
} }
}; };
template<typename T>
struct AttrTriggerNonDefaultEntry {
using TSelf = AttrTriggerNonDefaultEntry<T>;
// constructor
AttrTriggerNonDefaultEntry(
AttrVisitor* visitor, const char* key, T* data)
: visitor_(visitor), key_(key), data_(data) {}
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
if (trigger_) {
visitor_->Visit(key_, data_);
}
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
return *this;
}
TSelf& set_default(const T& value) {
if (AttrsEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
}
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
return *this;
}
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
return *this;
}
private:
AttrVisitor* visitor_;
const char * key_;
T *data_;
bool trigger_{true};
};
class AttrNonDefaultVisitor {
public:
explicit AttrNonDefaultVisitor(AttrVisitor* visitor)
: visitor_(visitor) {
}
template<typename T>
AttrTriggerNonDefaultEntry<T>
operator()(const char* key, T* value) {
return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
}
private:
AttrVisitor* visitor_;
};
} // namespace detail } // namespace detail
/*! /*!
...@@ -665,6 +724,11 @@ class AttrsNode : public BaseAttrsNode { ...@@ -665,6 +724,11 @@ class AttrsNode : public BaseAttrsNode {
self()->__VisitAttrs__(vis); self()->__VisitAttrs__(vis);
} }
void VisitNonDefaultAttrs(AttrVisitor* v) final {
detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final { void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
CHECK_EQ(args.size() % 2, 0); CHECK_EQ(args.size() % 2, 0);
const int kLinearSearchBound = 16; const int kLinearSearchBound = 16;
......
...@@ -13,7 +13,7 @@ namespace tvm { ...@@ -13,7 +13,7 @@ namespace tvm {
namespace relay { namespace relay {
/*! \brief Attributes used in convolution operators */ /*! \brief Attributes used in convolution operators */
struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> { struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides; Array<IndexExpr> strides;
Array<IndexExpr> padding; Array<IndexExpr> padding;
Array<IndexExpr> dilation; Array<IndexExpr> dilation;
...@@ -25,7 +25,7 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> { ...@@ -25,7 +25,7 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
std::string out_layout; std::string out_layout;
DataType out_dtype; DataType out_dtype;
TVM_DECLARE_ATTRS(ConvAttrs, "relay.attrs.ConvAttrs") { TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1})) TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution."); .describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0})) TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
...@@ -55,14 +55,14 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> { ...@@ -55,14 +55,14 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc." .describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively."); "dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("__undef__") TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc." .describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width" "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout."); "dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none. // use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype) TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0)) .set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting"); .describe("Output data type, set to explicit type under mixed precision setting");
} }
}; };
...@@ -123,7 +123,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> { ...@@ -123,7 +123,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width" "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively."); "dimensions respectively.");
TVM_ATTR_FIELD(out_dtype) TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0)) .set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting"); .describe("Output data type, set to explicit type under mixed precision setting");
} }
}; };
......
...@@ -78,7 +78,7 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> { ...@@ -78,7 +78,7 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
.describe("Target shape."); .describe("Target shape.");
TVM_ATTR_FIELD(dtype) TVM_ATTR_FIELD(dtype)
.describe("Target data type.") .describe("Target data type.")
.set_default(Int(0)); .set_default(NullValue<DataType>());
} }
}; // struct InitOpAttrs }; // struct InitOpAttrs
......
...@@ -181,8 +181,6 @@ class FunctionNode : public ExprNode { ...@@ -181,8 +181,6 @@ class FunctionNode : public ExprNode {
public: public:
/*! \brief Function parameters */ /*! \brief Function parameters */
tvm::Array<Var> params; tvm::Array<Var> params;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*! /*!
* \brief * \brief
* The expression which represents the computation of the function, * The expression which represents the computation of the function,
...@@ -190,6 +188,8 @@ class FunctionNode : public ExprNode { ...@@ -190,6 +188,8 @@ class FunctionNode : public ExprNode {
* or sub-expressions may reference the type variables. * or sub-expressions may reference the type variables.
*/ */
Expr body; Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*! /*!
* \brief Type parameters of the function. * \brief Type parameters of the function.
* Enables the function to vary its type based on these. * Enables the function to vary its type based on these.
...@@ -201,8 +201,8 @@ class FunctionNode : public ExprNode { ...@@ -201,8 +201,8 @@ class FunctionNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params); v->Visit("params", &params);
v->Visit("ret_type", &ret_type);
v->Visit("body", &body); v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params); v->Visit("type_params", &type_params);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
...@@ -217,8 +217,8 @@ class FunctionNode : public ExprNode { ...@@ -217,8 +217,8 @@ class FunctionNode : public ExprNode {
TVM_DLL FuncType func_type_annotation() const; TVM_DLL FuncType func_type_annotation() const;
TVM_DLL static Function make(tvm::Array<Var> params, TVM_DLL static Function make(tvm::Array<Var> params,
Type ret_type,
Expr body, Expr body,
Type ret_type,
tvm::Array<TypeParam> ty_params); tvm::Array<TypeParam> ty_params);
static constexpr const char* _type_key = "relay.Function"; static constexpr const char* _type_key = "relay.Function";
......
...@@ -48,6 +48,11 @@ class OpNode : public relay::ExprNode { ...@@ -48,6 +48,11 @@ class OpNode : public relay::ExprNode {
*/ */
std::string attrs_type_key; std::string attrs_type_key;
/*! /*!
* \brief attribute type index,
* this field varies in each run and is not exposed to frontend.
*/
uint32_t attrs_type_index{0};
/*!
* \brief number of input arguments to the operator, * \brief number of input arguments to the operator,
* -1 means it is variable length * -1 means it is variable length
*/ */
...@@ -416,6 +421,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*) ...@@ -416,6 +421,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*)
inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*) inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*)
const std::string& type_key) { const std::string& type_key) {
get()->attrs_type_key = type_key; get()->attrs_type_key = type_key;
get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str());
return *this; return *this;
} }
......
...@@ -14,13 +14,15 @@ from . import libinfo ...@@ -14,13 +14,15 @@ from . import libinfo
#---------------------------- #----------------------------
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
string_types = (str,) string_types = (str,)
numeric_types = (float, int, np.float32, np.int32) integer_types = (int, np.int32)
numeric_types = integer_types + (float, np.float32)
# this function is needed for python3 # this function is needed for python3
# to convert ctypes.char_p .value back to python str # to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8') py_str = lambda x: x.decode('utf-8')
else: else:
string_types = (basestring,) string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32) integer_types = (int, long, np.int32)
numeric_types = integer_types + (float, np.float32)
py_str = lambda x: x py_str = lambda x: x
......
# pylint: disable=wildcard-import, redefined-builtin # pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler.""" """The Relay IR namespace containing the IR definition and compiler."""
from . import base from . import base
from . import ty from . import ty
...@@ -19,6 +19,9 @@ from . import image ...@@ -19,6 +19,9 @@ from . import image
# Span # Span
Span = base.Span Span = base.Span
# Env
Environment = env.Environment
# Type # Type
Type = ty.Type Type = ty.Type
TupleType = ty.TupleType TupleType = ty.TupleType
...@@ -40,3 +43,7 @@ Call = expr.Call ...@@ -40,3 +43,7 @@ Call = expr.Call
Let = expr.Let Let = expr.Let
If = expr.If If = expr.If
TupleGetItem = expr.TupleGetItem TupleGetItem = expr.TupleGetItem
# helper functions
var = expr.var
const = expr.const
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make from . import _make
from . import _expr
NodeBase = NodeBase NodeBase = NodeBase
...@@ -20,7 +21,19 @@ def register_relay_node(type_key=None): ...@@ -20,7 +21,19 @@ def register_relay_node(type_key=None):
return _register_tvm_node(type_key) return _register_tvm_node(type_key)
class RelayNode(NodeBase):
def astext(self):
"""Get the text format of the expression.
Returns
-------
text : str
The text format of the expression.
"""
return _expr._text_print(self)
@register_relay_node @register_relay_node
class Span(NodeBase): class Span(RelayNode):
def __init__(self, source, lineno, col_offset): def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program.""" """A global environment storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, NodeBase from .base import register_relay_node, RelayNode
from . import _make from . import _make
from . import _env from . import _env
@register_relay_node @register_relay_node
class Environment(NodeBase): class Environment(RelayNode):
"""The global Relay environment containing functions, """The global Relay environment containing functions,
options and more. options and more.
""" """
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay.""" """The expression nodes of Relay."""
from __future__ import absolute_import from __future__ import absolute_import
from .base import NodeBase, register_relay_node
from . import _expr import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make from . import _make
from . import ty as _ty
from .._ffi import base as _base, node as _node
from .. import nd as _nd
from .. import convert from .. import convert
class Expr(NodeBase): class Expr(RelayNode):
"""The base type for all Relay expressions.""" """The base type for all Relay expressions."""
@property @property
def checked_type(self): def checked_type(self):
...@@ -56,7 +60,7 @@ class Tuple(Expr): ...@@ -56,7 +60,7 @@ class Tuple(Expr):
@register_relay_node @register_relay_node
class Var(Expr): class Var(Expr):
"""A local variable in Tvm.Relay. """A local variable in Relay.
Local variable can be used to declare input Local variable can be used to declare input
arguments to a function, or intermediate variables. arguments to a function, or intermediate variables.
...@@ -101,26 +105,26 @@ class Function(Expr): ...@@ -101,26 +105,26 @@ class Function(Expr):
params: List[tvm.relay.Var] params: List[tvm.relay.Var]
List of input parameters to the function. List of input parameters to the function.
ret_type: tvm.relay.Type
The return type annotation of the function.
body: tvm.relay.Expr body: tvm.relay.Expr
The body of the function. The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]] type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only The additional type parameters, this is only
used in advanced usecase of template functions. used in advanced usecase of template functions.
""" """
def __init__(self, def __init__(self,
params, params,
ret_type,
body, body,
ret_type=None,
type_params=None): type_params=None):
if type_params is None: if type_params is None:
type_params = convert([]) type_params = convert([])
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Function, params, ret_type, body, type_params) _make.Function, params, body, ret_type, type_params)
@register_relay_node @register_relay_node
...@@ -158,7 +162,7 @@ class Let(Expr): ...@@ -158,7 +162,7 @@ class Let(Expr):
Parameters Parameters
---------- ----------
var: tvm.relay.Var variable: tvm.relay.Var
The local variable to be bound. The local variable to be bound.
value: tvm.relay.Expr value: tvm.relay.Expr
...@@ -167,9 +171,9 @@ class Let(Expr): ...@@ -167,9 +171,9 @@ class Let(Expr):
body: tvm.relay.Expr body: tvm.relay.Expr
The body of the let binding. The body of the let binding.
""" """
def __init__(self, var, value, body): def __init__(self, variable, value, body):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Let, var, value, body) _make.Let, variable, value, body)
@register_relay_node @register_relay_node
...@@ -208,4 +212,105 @@ class TupleGetItem(Expr): ...@@ -208,4 +212,105 @@ class TupleGetItem(Expr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index) _make.TupleGetItem, tuple_value, index)
debug_print = _expr._debug_print
class TupleWrapper(_node.NodeGeneric):
"""TupleWrapper.
This class is a Python wrapper for a Relay tuple of known size.
It allows for accessing the fields of the Relay tuple as though
it were a Python tuple.
Parameters
----------
tuple_value: tvm.relay.Expr
The input tuple
size: int
The size of the tuple.
"""
def __init__(self, tuple_value, size):
self.tuple_value = tuple_value
self.size = size
def asnode(self):
"""Returns the underlying Relay tuple if this wrapper is passed
as an argument to an FFI function."""
return self.tuple_value
def __getitem__(self, key):
return self.tuple_value.fields[key]
def __len__(self):
return len(self.tuple_value.fields)
def var(name_hint,
type_annotation=None,
shape=None,
dtype="float32"):
"""Create a new tvm.relay.Var.
This is a simple wrapper function that allows specify
shape and dtype directly.
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
type_annotation: Optional[tvm.relay.Type, str]
The type annotation on the variable.
When type_annotation is a str, we will create a scalar variable.
shape: Optional[List[tvm.Expr]]
The shape of the tensor type.
dtype: str, optional
The data type of the tensor.
Examples
--------
.. code-block:: python
# The following 4 lines are equivalent to each other
x = tvm.relay.Var("x", tvm.relay.TensorType([1, 2]))
x = tvm.relay.var("x", tvm.relay.TensorType([1, 2]))
x = tvm.relay.var("x", shape=[1, 2])
x = tvm.relay.var("x", shape=[1, 2], dtype="float32")
# The following 2 lines are equivalent to each other.
y = tvm.relay.var("x", "float32")
y = tvm.relay.var("x", shape=(), dtype="float32")
"""
if type_annotation is not None and shape is not None:
raise ValueError("Can only specify either type_annotation or shape.")
if shape is not None:
type_annotation = _ty.TensorType(shape, dtype)
elif isinstance(type_annotation, str):
type_annotation = _ty.TensorType((), type_annotation)
return Var(name_hint, type_annotation)
def const(value, dtype=None):
"""Create a constant value.
Parameters
----------
value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
The constant value.
dtype: str, optional
The data type of the value.
"""
if isinstance(value, _base.numeric_types):
value = _np.array(value, dtype=dtype)
elif isinstance(value, (bool, list)):
value = _np.array(value, dtype=dtype)
if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")
return Constant(value)
...@@ -11,32 +11,6 @@ from .expr import Expr, Constant, Let, Var, Function, If ...@@ -11,32 +11,6 @@ from .expr import Expr, Constant, Let, Var, Function, If
from .env import Environment from .env import Environment
class TupleWrapper(tvm._ffi.node.NodeGeneric):
"""TupleWrapper.
This class is a Python wrapper for a Relay tuple of known size.
It allows for accessing the fields of the Relay tuple as though
it were a Python tuple.
"""
def __init__(self, tuple_value, size):
self.tuple_value = tuple_value
self.size = size
def asnode(self):
"""Returns the underlying Relay tuple if this wrapper is passed
as an argument to an FFI function."""
return self.tuple_value
def __getitem__(self, key):
return self.tuple_value.fields[key]
def __len__(self):
return len(self.tuple_value.fields)
def _convert_to_value(arg, ctxt=tvm.cpu(0)): def _convert_to_value(arg, ctxt=tvm.cpu(0)):
# type: (Any, tvm.Context) -> tvm.nd.NDArray # type: (Any, tvm.Context) -> tvm.nd.NDArray
"""Convert Python values into the appropriate types """Convert Python values into the appropriate types
...@@ -132,8 +106,8 @@ class PartialFunc(object): ...@@ -132,8 +106,8 @@ class PartialFunc(object):
"""Converts a PartialFunc into a :py:class:`~relay.Function`.""" """Converts a PartialFunc into a :py:class:`~relay.Function`."""
return Function( return Function(
self.params, self.params,
self.ret_type,
self.body, self.body,
self.ret_type,
self.type_params) self.type_params)
#pylint: disable=invalid-name #pylint: disable=invalid-name
...@@ -325,7 +299,7 @@ class IRBuilder(object): ...@@ -325,7 +299,7 @@ class IRBuilder(object):
def _on_exit(): def _on_exit():
bindings, _, _, ret_value = self.exit_scope() bindings, _, _, ret_value = self.exit_scope()
exp = _mk_let(bindings, ret_value) exp = _mk_let(bindings, ret_value)
self.env.add(name, Function(params, ret_type, exp)) self.env.add(name, Function(params, exp, ret_type))
return WithScope(10, _on_exit) return WithScope(10, _on_exit)
......
"""Neural network operations.""" """Neural network operations."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from tvm.relay.ir_builder import TupleWrapper from ...expr import TupleWrapper
from . import _make from . import _make
...@@ -145,7 +145,7 @@ def conv2d_transpose(data, ...@@ -145,7 +145,7 @@ def conv2d_transpose(data,
weight_layout, output_padding, out_dtype) weight_layout, output_padding, out_dtype)
def softmax(data, axis): def softmax(data, axis=1):
r"""Computes softmax. r"""Computes softmax.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)} .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
...@@ -158,7 +158,7 @@ def softmax(data, axis): ...@@ -158,7 +158,7 @@ def softmax(data, axis):
data: relay.Expr data: relay.Expr
The input data to the operator. The input data to the operator.
axis: int axis: int, optional
The axis to sum over when computing softmax The axis to sum over when computing softmax
Returns Returns
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language.""" """The type nodes of the Relay language."""
from enum import IntEnum from enum import IntEnum
from .base import NodeBase, register_relay_node from .base import RelayNode, register_relay_node
from . import _make from . import _make
class Type(NodeBase): class Type(RelayNode):
"""The base type for all Relay types.""" """The base type for all Relay types."""
def __eq__(self, other): def __eq__(self, other):
...@@ -21,27 +21,25 @@ class Type(NodeBase): ...@@ -21,27 +21,25 @@ class Type(NodeBase):
"""Compares two Relay types by referential equality.""" """Compares two Relay types by referential equality."""
return super().__eq__(other) return super().__eq__(other)
@register_relay_node @register_relay_node
class TensorType(Type): class TensorType(Type):
"""A concrete TensorType in Relay, see tvm/relay/type.h for more details. """A concrete TensorType in Relay.
This is the type assigned to tensor's with a known dype and shape. For This is the type assigned to tensor's with a known dype and shape. For
example a tensor of `float32` and `(5, 5)`. example a tensor of `float32` and `(5, 5)`.
"""
def __init__(self, shape, dtype):
"""Construct a tensor type.
Parameters Parameters
---------- ----------
shape: list of tvm.Expr shape: List[tvm.Expr]
dtype: str The shape of the Tensor
Returns dtype: str, optional
------- The content data type.
tensor_type: The TensorType """
""" def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(_make.TensorType, shape, dtype) self.__init_handle_by_constructor__(
_make.TensorType, shape, dtype)
class Kind(IntEnum): class Kind(IntEnum):
......
...@@ -17,11 +17,15 @@ namespace tvm { ...@@ -17,11 +17,15 @@ namespace tvm {
template <typename FType> template <typename FType>
class AttrFunctor; class AttrFunctor;
#define ATTR_FUNCTOR_DEFAULT \
{ return VisitAttrDefault_(op, std::forward<Args>(args)...); }
#define ATTR_FUNCTOR_DISPATCH(OP) \ #define ATTR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \ vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \ [](const NodeRef& n, TSelf* self, Args... args) { \
return self->Visit_(static_cast<const OP*>(n.node_.get()), \ return self->VisitAttr_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \ std::forward<Args>(args)...); \
}); \ }); \
// A functor for common attribute information. // A functor for common attribute information.
...@@ -40,21 +44,21 @@ class AttrFunctor<R(const NodeRef& n, Args...)> { ...@@ -40,21 +44,21 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
* \param args Additional arguments. * \param args Additional arguments.
* \return The result of the call * \return The result of the call
*/ */
virtual R Visit(const NodeRef& n, Args... args) { virtual R VisitAttr(const NodeRef& n, Args... args) {
static FType vtable = InitVTable(); static FType vtable = InitVTable();
if (vtable.can_dispatch(n)) { if (vtable.can_dispatch(n)) {
return vtable(n, this, std::forward<Args>(args)...); return vtable(n, this, std::forward<Args>(args)...);
} else { } else {
return VisitDefault_(n, std::forward<Args>(args)...); return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
} }
} }
virtual R Visit_(const ArrayNode* op, Args... args) = 0; virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R Visit_(const StrMapNode* op, Args... args) = 0; virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R Visit_(const ir::IntImm* op, Args... args) = 0; virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R Visit_(const ir::UIntImm* op, Args... args) = 0; virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R Visit_(const ir::FloatImm* op, Args... args) = 0; virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R Visit_(const ir::StringImm* op, Args... args) = 0; virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitDefault_(const NodeRef& n, Args... args) = 0; virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
private: private:
// initialize the vtable. // initialize the vtable.
......
...@@ -11,6 +11,10 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) { ...@@ -11,6 +11,10 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict); v->Visit("__dict__", &dict);
} }
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
}
void DictAttrsNode::InitByPackedArgs( void DictAttrsNode::InitByPackedArgs(
const runtime::TVMArgs& args, bool allow_unknown) { const runtime::TVMArgs& args, bool allow_unknown) {
for (int i = 0; i < args.size(); i += 2) { for (int i = 0; i < args.size(); i += 2) {
...@@ -55,48 +59,48 @@ class AttrsEqualChecker : ...@@ -55,48 +59,48 @@ class AttrsEqualChecker :
if (!equal_) return false; if (!equal_) return false;
if (lhs.same_as(rhs)) return true; if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false; if (!lhs.defined() || !rhs.defined()) return false;
if (!this->Visit(lhs, rhs)) { if (!this->VisitAttr(lhs, rhs)) {
equal_ = false; equal_ = false;
} }
return equal_; return equal_;
} }
bool VisitDefault_(const NodeRef& lhs, const NodeRef& other) final { bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final {
if (lhs->derived_from<BaseAttrsNode>()) { if (lhs->derived_from<BaseAttrsNode>()) {
return static_cast<const BaseAttrsNode*>(lhs.get())->ContentEqual(other.get()); return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get());
} }
return lhs.same_as(other); return lhs == other.get();
} }
bool Visit_(const IntImm* lhs, const NodeRef& other) final { bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<IntImm>()) { if (const auto* rhs = other.as<IntImm>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
} }
bool Visit_(const UIntImm* lhs, const NodeRef& other) final { bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<UIntImm>()) { if (const auto* rhs = other.as<UIntImm>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
} }
bool Visit_(const FloatImm* lhs, const NodeRef& other) final { bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<FloatImm>()) { if (const auto* rhs = other.as<FloatImm>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
} }
bool Visit_(const StringImm* lhs, const NodeRef& other) final { bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StringImm>()) { if (const auto* rhs = other.as<StringImm>()) {
return lhs->value == rhs->value; return lhs->value == rhs->value;
} }
return false; return false;
} }
bool Visit_(const ArrayNode* lhs, const NodeRef& other) final { bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<ArrayNode>()) { if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false; if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) { for (size_t i = 0; i < lhs->data.size(); ++i) {
...@@ -106,7 +110,7 @@ class AttrsEqualChecker : ...@@ -106,7 +110,7 @@ class AttrsEqualChecker :
return true; return true;
} }
bool Visit_(const StrMapNode* lhs, const NodeRef& other) final { bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StrMapNode>()) { if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false; if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) { for (const auto& kv : lhs->data) {
...@@ -127,38 +131,38 @@ class AttrContentHasher : ...@@ -127,38 +131,38 @@ class AttrContentHasher :
public: public:
size_t result_{0}; size_t result_{0};
void VisitDefault_(const NodeRef& value) final { void VisitAttrDefault_(const Node* value) final {
if (value->derived_from<BaseAttrsNode>()) { if (value->derived_from<BaseAttrsNode>()) {
Update(static_cast<const BaseAttrsNode*>(value.get())->ContentHash()); Update(static_cast<const BaseAttrsNode*>(value)->ContentHash());
} else { } else {
Update(NodeHash()(value)); Update(NodeHash()(GetRef<NodeRef>(value)));
} }
} }
void Visit_(const IntImm* op) final { void VisitAttr_(const IntImm* op) final {
Update(std::hash<int64_t>()(op->value)); Update(std::hash<int64_t>()(op->value));
} }
void Visit_(const UIntImm* op) final { void VisitAttr_(const UIntImm* op) final {
Update(std::hash<uint64_t>()(op->value)); Update(std::hash<uint64_t>()(op->value));
} }
void Visit_(const FloatImm* op) final { void VisitAttr_(const FloatImm* op) final {
Update(std::hash<double>()(op->value)); Update(std::hash<double>()(op->value));
} }
void Visit_(const StringImm* op) final { void VisitAttr_(const StringImm* op) final {
Update(std::hash<std::string>()(op->value)); Update(std::hash<std::string>()(op->value));
} }
void Visit_(const ArrayNode* op) final { void VisitAttr_(const ArrayNode* op) final {
Update(op->data.size()); Update(op->data.size());
for (size_t i = 0; i < op->data.size(); ++i) { for (size_t i = 0; i < op->data.size(); ++i) {
this->Visit(NodeRef(op->data[i])); this->VisitAttr(NodeRef(op->data[i]));
} }
} }
void Visit_(const StrMapNode* lhs) final { void VisitAttr_(const StrMapNode* lhs) final {
using Entry = std::pair<std::string, NodePtr<Node> >; using Entry = std::pair<std::string, NodePtr<Node> >;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end()); std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) { std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
...@@ -166,7 +170,7 @@ class AttrContentHasher : ...@@ -166,7 +170,7 @@ class AttrContentHasher :
}); });
for (const Entry& kv : data) { for (const Entry& kv : data) {
Update(std::hash<std::string>()(kv.first)); Update(std::hash<std::string>()(kv.first));
this->Visit(NodeRef(kv.second)); this->VisitAttr(NodeRef(kv.second));
} }
} }
...@@ -184,7 +188,7 @@ bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) { ...@@ -184,7 +188,7 @@ bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) {
size_t AttrsHash::Hash(const NodeRef& node) { size_t AttrsHash::Hash(const NodeRef& node) {
if (!node.defined()) return 0; if (!node.defined()) return 0;
AttrContentHasher hasher; AttrContentHasher hasher;
hasher.Visit(node); hasher.VisitAttr(node);
return hasher.result_; return hasher.result_;
} }
......
...@@ -208,6 +208,8 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -208,6 +208,8 @@ class JSONAttrGetter : public AttrVisitor {
node_->type_key = node->type_key(); node_->type_key = node->type_key();
// sepcially handle global object // sepcially handle global object
auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key); auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key);
CHECK(f != nullptr)
<< "Node type \'" << node_->type_key << "\' is not registered in TVM";
if (f->fglobal_key != nullptr) { if (f->fglobal_key != nullptr) {
node_->global_key = f->fglobal_key(node); node_->global_key = f->fglobal_key(node);
return; return;
......
...@@ -51,6 +51,8 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { ...@@ -51,6 +51,8 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
return Span(n); return Span(n);
} }
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span") TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SpanNode::make(args[0], args[1], args[2]); *ret = SpanNode::make(args[0], args[1], args[2]);
......
/*!
* Copyright (c) 2018 by Contributors
* \file debug_printer.cc
* \brief A pretty printer for the Relay IR.
* As we had not determined a formal syntax yet, right now it is only for debug purpose.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/error.h>
#include <iostream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include <string>
#include <vector>
#include <iostream>
#include "../pass/type_functor.h"
#include "doc.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
Doc KindDocify(TypeParamNode::Kind k) {
switch (k) {
case TypeParamNode::kShapeVar:
return DocOfStr("ShapeVar");
case TypeParamNode::kShape:
return DocOfStr("Shape");
case TypeParamNode::kBaseType:
return DocOfStr("BaseType");
case TypeParamNode::kType:
return DocOfStr("Type");
default:
LOG(FATAL) << "unreachable code: case not handle in kind";
throw; // log fatal throw but compiler doesnt know
}
}
template<typename T>
std::vector<Doc> MapDocify(const tvm::Array<T>& arr, const std::function<Doc(const T&)>& f) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(f(arr[i]));
}
return vec;
}
template<typename T, typename Hash = std::hash<T>, typename Eq = std::equal_to<T>>
class Counter {
std::unordered_map<T, size_t, Hash, Eq> cnt_;
public:
Counter() = default;
Counter(const Counter&) = delete;
size_t operator()(const T& t) {
auto v = cnt_.count(t) == 0 ? 0 : cnt_.at(t) + 1;
cnt_[t] = v;
return v;
}
};
std::string Mangle(const std::string& str, size_t s) {
return str + "_" + std::to_string(s);
// return s == 0 ? str : str + "_" + std::to_string(s - 1);
// the above line look prettier but is dangerous:
// suppose we have x, x, x_0. mangling will give x, x_0, x_0!
// the save approach give x_0, x_1, x_0_1, and in fact never clash:
// stripping _([0-9]*) is invert of mangle under all circumstances.
// another problem is we need to prevent Var/TypeParam/GlobalVar clashing each other.
}
constexpr size_t indent = 2;
struct TypeParamName {
bool operator==(const TypeParamName&) const {
return true;
}
};
struct mhash {
size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept {
return 0;
}
};
class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
Environment env;
Counter<TypeParamName, mhash> cnt;
std::unordered_map<TypeParam, Doc, NodeHash, NodeEqual> map;
std::vector<Doc> DocifyTypeArray(const tvm::Array<Type>& arr) {
return MapDocify<Type>(arr, [=](const Type& t) { return Docify(t); });
}
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) {
return Docify(tp);
});
}
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
return MapDocify<TypeConstraint>(arr, [=](const TypeConstraint& tc) { return Docify(tc); });
}
Doc VisitType_(const TensorTypeNode* t) final {
return DocOfStr("tensor");
}
Doc VisitType_(const TypeParamNode* p) final {
auto tp = GetRef<TypeParam>(p);
if (map.count(tp) == 0) {
auto name =
DocOfStr(Mangle("tp", cnt(TypeParamName())) +
std::string(":")) +
KindDocify(p->kind);
map.insert(std::pair<TypeParam, Doc>(tp, name));
}
return map.at(tp);
}
Doc Quantify(const tvm::Array<TypeParam>& tp, const Doc& d) {
if (tp.size() == 0) {
return d;
}
return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d;
}
Doc Constraint(const tvm::Array<TypeConstraint>& tc, const Doc& d) {
if (tc.size() == 0) {
return d;
}
return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d;
}
Doc VisitType_(const FuncTypeNode* f) final {
auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() +
DocOfStr("->") + Sep() + Docify(f->ret_type);
return Group(Quantify(f->type_params,
Constraint(f->type_constraints, inner)));
}
Doc VisitType_(const TypeRelationNode* r) final {
return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")");
}
Doc VisitType_(const TupleTypeNode* t) final {
return Seq("<", DocifyTypeArray(t->fields), ">");
}
Doc VisitType_(const IncompleteTypeNode* i) final {
return DocOfStr("_");
}
public:
TypeDocifier(const Environment& env) : env(env) { }
Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); }
};
class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
Environment env;
Counter<std::string> cnt;
std::unordered_map<Var, std::string, NodeHash, NodeEqual> map;
TypeDocifier td;
std::string VarName(const Var& v) {
if (map.count(v) == 0) {
map.insert(std::pair<Var, std::string>(v, Mangle(v->name_hint, cnt(v->name_hint))));
}
return map.at(v);
}
Doc TypeAnnotation(const Doc& d, const Type& t) {
// test for t being null. probably shouldnt has null. should talk to jared.
if (!t.get() || t.as<IncompleteTypeNode>()) {
return d;
} else {
return d + DocOfStr(":") + td.Docify(t);
}
}
std::vector<Doc> DocifyExprArray(const tvm::Array<Expr>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
}
return vec;
}
std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
std::vector<Doc> vec;
for (Var param : arr) {
vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
param->type_annotation));
}
return vec;
}
Doc VisitExpr_(const ConstantNode* c) final {
return DocOfStr("some_constant");
}
Doc VisitExpr_(const TupleNode* t) final {
return Seq("<", DocifyExprArray(t->fields), ">");
}
Doc VisitExpr_(const VarNode* v) final {
return DocOfStr(VarName(GetRef<Var>(v)));
}
Doc VisitExpr_(const GlobalVarNode* g) final {
return DocOfStr(g->name_hint);
}
Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() +
Block(indent, "{", Docify(f->body), "}"));
}
Doc VisitExpr_(const CallNode* c) final {
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}
Doc VisitExpr_(const LetNode* l) final {
return Group(DocOfStr("let") + Sep() +
TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body));
}
Doc VisitExpr_(const IfNode* i) final {
return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() +
Block(indent, "{", Docify(i->true_branch), "}") + Sep() +
DocOfStr("else") + Sep() +
Block(indent, "{", Docify(i->false_branch), "}"));
}
Doc VisitExpr_(const OpNode* o) final {
return DocOfStr(o->name);
}
Doc VisitExpr_(const TupleGetItemNode* g) final {
return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
}
public:
ExprDocifier(const Environment& env) : env(env), td(env) { }
Doc Docify(const Expr& e) { return (*this)(e); }
};
Doc DocOfExpr(const Environment& env, const Expr& expr) {
ExprDocifier d(env);
return d.Docify(expr);
}
Doc DocOfType(const Environment& env, const Type& expr) {
TypeDocifier d(env);
return d.Docify(expr);
}
RDoc ExprRDoc(const Environment& env, const Expr& expr) {
return Layout(DocOfExpr(env, expr));
}
RDoc TypeRDoc(const Environment& env, const Type& expr) {
return Layout(DocOfType(env, expr));
}
std::ostream & DebugPrint(const Environment& env, const Expr& e, std::ostream& os) {
return os << ExprRDoc(env, e);
}
std::ostream & DebugPrint(const Environment& env, const Type& t, std::ostream& os) {
return os << TypeRDoc(env, t);
}
std::string PrintExpr(const Environment& env, const Expr& e) {
std::stringstream ss;
ss << ExprRDoc(env, e);
return ss.str();
}
std::string PrintType(const Environment& env, const Type& t) {
std::stringstream ss;
ss << TypeRDoc(env, t);
return ss.str();
}
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
*ret = PrintExpr(args[0], Downcast<Expr>(x));
}
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file doc.h
* \brief A pretty printer DSL for constructing (Doc) and formatting (RDoc) documents.
* It is based heavily on Philip Wadler's "A prettier printer."
* See https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf
* for more details.
*
* Since the original paper uses call by value for efficiency, everything doc function is maximally lazy.
* You can probably yank speed by doing strict analysis and removing some Lazy (if this is bottleneck).
*/
#ifndef TVM_RELAY_IR_DOC_H_
#define TVM_RELAY_IR_DOC_H_
#include <tvm/relay/error.h>
#include <unordered_map>
#include <utility>
#include <string>
#include <functional>
#include <vector>
#include <memory>
#include <ostream>
#include <map>
namespace tvm {
namespace relay {
/*! \brief A Document represent structured text.
* beside having unstructured string, it capture different ways to compose them -
* line break, space, indentation, representation choice.
*/
struct Doc;
/*! \brief RDoc represent rendered document.
* all the high level detail on the document, such as indentation, choice, has been removed.
* there is only one single, straight forward way to print it.
*/
struct RDoc;
//! \brief Empty document
inline Doc Nil();
//! \brief Concatenate two documents
inline Doc App(const Doc& l, const Doc& r);
//! \brief Indent a document
inline Doc Nest(size_t width, const Doc& doc);
//! \brief Lift string to a document
inline Doc DocOfStr(const std::string& text);
//! \brief New line
inline Doc Endl();
//! \brief Remove all line break from the Document.
inline Doc Flatten(const Doc& d);
/*! \brief Choose between two possible layouts.
* assume Flatten(l) == Flatten(r), and l need to be more compact.
*/
inline Doc Choose(const Doc& l, const Doc& r);
//! \brief Use a single line if possible
inline Doc Group(const Doc& d);
//! \brief print an RDoc
inline std::ostream& operator<<(std::ostream& os, const RDoc& rdoc);
/*! \brief Joins a vector of documents with a given separator document
* \example Join(["a", "b, "c"], ", ") => "a, b, c"
* \param vec the vector of documents
* \param sep the separator between documents
*/
inline Doc Join(const std::vector<Doc>& vec, const Doc& sep);
/*! \brief Creates an indented block.
* \param indent the indentation size
* \param open the opening string
* \param body the body of the block
* \param close the closing string
*/
inline Doc Block(size_t indent, const std::string& open,
const Doc& body, const std::string& close);
/*! \brief Creates a comma-separated sequence with opening and closing strings.
* \param open the opening string
* \param body the body of the Block
* \param close the closing string
*/
inline Doc Seq(const std::string& open,
const std::vector<Doc>& body, const std::string& close);
//! \brief Either a space or a new line
inline Doc Sep();
/*! \brief Layout a document to a given width
* \param d the document to render
* \param width the line width
*/
inline RDoc Layout(const Doc& d, size_t width = 80);
// end of API, start of implementation
template<typename T>
struct LazyNode {
mutable std::function<T()> thunk;
explicit LazyNode(const std::function<T()>& thunk) : thunk(thunk) { }
};
//! \brief denote a value that will be computed (at most once) on need.
template<typename T>
struct Lazy {
std::shared_ptr<LazyNode<T> > lazy_node;
explicit Lazy(const std::function<T()>& thunk) :
lazy_node(std::make_shared<LazyNode<T>>(thunk)) { }
explicit Lazy(const T& value) : Lazy([=]() { return value; }) { }
explicit Lazy(const Lazy<Lazy<T>>& thunk) : Lazy([=]() { return thunk.get().get(); }) { }
// calculate the result.
// memoize it by replacing the thunk with a constant function which immediate return.
T get() const {
T res = lazy_node->thunk();
lazy_node->thunk = [=]() { return res; };
return res;
}
template<typename R>
Lazy<R> map(const std::function<R(const T&)>& func) const {
Lazy<T> self(*this);
return Lazy<R>([=]() -> R { return func(self.get()); });
}
};
struct NilNode;
struct AppNode;
struct NestNode;
struct TextNode;
struct LineNode;
struct ChoiceNode;
/*! \brief The inner representation of Doc.
* a doc represent structured text,
* and can be rendered onto screen while keeping the structure.
*/
struct DocNode {
/* a docnode is a union of the below node.
* exactly one of them will be non null.
* their meaning is denoted by the construction function of the same name.
* so for example, the meaning of AppNode is exactly a node construct by App.
*/
std::shared_ptr<NilNode> nil;
std::shared_ptr<AppNode> app;
std::shared_ptr<NestNode> nest;
std::shared_ptr<TextNode> text; // construct by DocOfStr
std::shared_ptr<LineNode> line;
std::shared_ptr<ChoiceNode> choice;
DocNode(std::shared_ptr<NilNode> nil,
std::shared_ptr<AppNode> app,
std::shared_ptr<NestNode> nest,
std::shared_ptr<TextNode> text,
std::shared_ptr<LineNode> line,
std::shared_ptr<ChoiceNode> choice) :
nil(nil),
app(app),
nest(nest),
text(text),
line(line),
choice(choice) { }
};
struct Doc {
Lazy<DocNode> doc;
explicit Doc(const DocNode& ed) : doc(ed) { }
explicit Doc(const Lazy<Doc>& ldoc) :
doc(ldoc.map<Lazy<DocNode> >([](const Doc& d){ return d.doc; })) { }
Doc operator+(const Doc& r) const {
return App(*this, r);
}
template<typename T>
Lazy<T> Match(
const std::function<T()>& nilf,
const std::function<T(const Doc&, const Doc&)>& appf,
const std::function<T(size_t, const Doc&)>& nestf,
const std::function<T(const std::string&)>& textf,
const std::function<T()>& linef,
const std::function<T(const Doc&, const Doc&)>& choicef) const;
};
struct NilNode { };
struct AppNode {
Doc left, right;
AppNode(const Doc& left, const Doc& right) : left(left), right(right) { }
};
struct NestNode {
size_t space;
Doc doc;
NestNode(size_t space, const Doc& doc) : space(space), doc(doc) { }
};
struct TextNode {
std::string text;
explicit TextNode(const std::string& text) : text(text) { }
};
struct LineNode { };
struct ChoiceNode {
Doc left, right;
ChoiceNode(const Doc& left, const Doc& right) : left(left), right(right) { }
};
template<typename T>
Lazy<T> Doc::Match(
const std::function<T()>& nilf,
const std::function<T(const Doc&, const Doc&)>& appf,
const std::function<T(size_t, const Doc&)>& nestf,
const std::function<T(const std::string&)>& textf,
const std::function<T()>& linef,
const std::function<T(const Doc&, const Doc&)>& choicef) const {
return doc.map<T>([=](const DocNode& d) {
if (d.nil) {
return nilf();
} else if (d.app) {
return appf(d.app->left, d.app->right);
} else if (d.nest) {
return nestf(d.nest->space, d.nest->doc);
} else if (d.text) {
return textf(d.text->text);
} else if (d.line) {
return linef();
} else {
return choicef(d.choice->left, d.choice->right);
}
});
}
//! \brief Empty document
inline Doc Nil() {
return Doc(DocNode(std::make_shared<NilNode>(), nullptr, nullptr, nullptr, nullptr, nullptr));
}
//! \brief Concatenate two documents
inline Doc App(const Doc& l, const Doc& r) {
return Doc(DocNode(
nullptr,
std::make_shared<AppNode>(l, r),
nullptr,
nullptr,
nullptr,
nullptr));
}
//! \brief Indent a document
inline Doc Nest(size_t width, const Doc& doc) {
auto x = std::make_shared<NestNode>(width, doc);
return Doc(DocNode(
nullptr,
nullptr,
std::make_shared<NestNode>(width, doc),
nullptr,
nullptr,
nullptr));
}
//! \brief Lift string to a document
inline Doc DocOfStr(const std::string& text) {
return Doc(DocNode(nullptr, nullptr, nullptr,
std::make_shared<TextNode>(text), nullptr, nullptr));
}
//! \brief New line
inline Doc Endl() {
return Doc(DocNode(nullptr, nullptr, nullptr, nullptr, std::make_shared<LineNode>(), nullptr));
}
/*! \brief Choose between two possible layouts.
* assume Flatten(l) == Flatten(r), and l need to be more compact.
*/
inline Doc Choose(const Doc& l, const Doc& r) {
return Doc(DocNode(
nullptr,
nullptr,
nullptr,
nullptr,
nullptr,
std::make_shared<ChoiceNode>(l, r)));
}
//! \brief Remove new line from the whole document.
inline Doc Flatten(const Doc& d) {
return Doc(d.Match<Doc>(
[]() { return Nil(); },
[](const Doc& l, const Doc& r) { return Flatten(l) + Flatten(r); },
[](size_t space, const Doc& doc) { return Flatten(doc); },
[](const std::string& str) { return DocOfStr(str); },
[]() { return DocOfStr(" "); },
[](const Doc& l, const Doc& r) { return Flatten(l); }));
}
//! \brief Use a single line if possible
inline Doc Group(const Doc& d) {
return Choose(Flatten(d), d);
}
struct RNilNode;
struct RTextNode;
struct RLineNode;
struct RDocNode {
std::shared_ptr<RNilNode> rnil;
std::shared_ptr<RTextNode> rtext;
std::shared_ptr<RLineNode> rline;
RDocNode(const std::shared_ptr<RNilNode>& rnil,
const std::shared_ptr<RTextNode>& rtext,
const std::shared_ptr<RLineNode>& rline) :
rnil(rnil), rtext(rtext), rline(rline) { }
};
/*! \brief RDoc represent rendered document.
* all the high level detail on the document, such as indentation, alternative, has been removed.
* there is only one single, straight forward way to print it.
*/
struct RDoc {
Lazy<RDocNode> doc;
explicit RDoc(const RDocNode& d) : doc(d) { }
explicit RDoc(const Lazy<RDoc>& ldoc) :
doc(ldoc.map<Lazy<RDocNode>>([](const RDoc& d){ return d.doc; })) { }
template<typename T>
Lazy<T> Match(
const std::function<T()> &rnilf,
const std::function<T(const std::string&, const RDoc&)>& rtextf,
const std::function<T(size_t, const RDoc&)>& rlinef) const;
};
inline std::ostream& operator<<(std::ostream& os, const RDoc& rdoc) {
return *rdoc.Match<std::ostream*>(
[&]() { return & os; },
[&](const std::string& text, const RDoc& r) {
return & (os << text << r);
},
[&](size_t space, const RDoc& r) {
return & (os << std::endl << std::string(space, ' ') << r);
}).get();
}
struct RNilNode { };
struct RTextNode {
std::string text;
RDoc rest;
RTextNode(const std::string& text, const RDoc& rest) : text(text), rest(rest) { }
};
struct RLineNode {
size_t space;
RDoc rest;
RLineNode(size_t space, const RDoc& rest) : space(space), rest(rest) { }
};
//! \brief Empty RDoc
inline RDoc RNil() { return RDoc(RDocNode(std::make_shared<RNilNode>(), nullptr, nullptr)); }
//! \brief RDoc that begin with std::string
inline RDoc RText(const std::string& text, const RDoc& rest) {
return RDoc(RDocNode(nullptr, std::make_shared<RTextNode>(text, rest), nullptr));
}
//! \brief RDoc that begin with a new line, followed by space
inline RDoc RLine(size_t space, const RDoc& rest) {
return RDoc(RDocNode(nullptr, nullptr, std::make_shared<RLineNode>(space, rest)));
}
template<typename T>
Lazy<T> RDoc::Match(
const std::function<T()>& rnilf,
const std::function<T(const std::string&, const RDoc&)>& rtextf,
const std::function<T(size_t, const RDoc&)>& rlinef) const {
return doc.map<T>([=](const RDocNode& rdoc) {
if (rdoc.rnil) {
return rnilf();
} else if (rdoc.rtext) {
return rtextf(rdoc.rtext->text, rdoc.rtext->rest);
} else {
return rlinef(rdoc.rline->space, rdoc.rline->rest);
}
});
}
template<typename T>
struct List;
template<typename T>
struct EagerList {
const std::shared_ptr<std::pair<T, List<T>>> cons;
};
//! \brief lazy list
template<typename T>
struct List {
Lazy<EagerList<T> > l;
List() : l([]() { return EagerList<T>({nullptr}); }) { }
List(const T& t, const List<T>& l) :
l([=]() { return EagerList<T>({std::make_shared<std::pair<T, List<T>>>(t, l)}); }) { }
template<typename R>
Lazy<R> Match(const std::function<R()>& nullf,
const std::function<R(const T&, const List<T>&)>& consf) const {
return l.template map<R>([=](const EagerList<T>& l) {
if (l.cons) {
return consf(l.cons->first, l.cons->second);
} else {
return nullf();
}
});
}
};
//! \brief Does x fit into line of size w?
inline bool Fits(int w, const RDoc& x) {
return (w >= 0) && x.Match<bool>(
[]() { return true; },
[=](const std::string& s, const RDoc& x) { return Fits(w - s.size(), x); },
[](size_t space, const RDoc& x) { return true; }).get();
}
//! \brief Choose the one that fits best.
inline RDoc Better(size_t w, size_t k, const RDoc& x, const RDoc& y) {
return Fits(w-k, x) ? x : y;
}
typedef std::pair<size_t/*indent size*/, Doc> best_arg;
inline RDoc Best(size_t w/*wrap width*/, size_t k/*space used*/,
const List<best_arg>& l/*to be rendered*/) {
return RDoc(l.Match<RDoc>(
[]() { return RNil(); },
[=](const best_arg& p, const List<best_arg>& z) {
return RDoc(p.second.Match<RDoc>(
[=]() { return Best(w, k, z); },
[=](const Doc& x, const Doc& y) {
return Best(
w,
k,
List<best_arg>(best_arg(p.first, x), List<best_arg>(best_arg(p.first, y), z))); },
[=](size_t j, const Doc& x) {
return Best(w, k, List<best_arg>(best_arg(p.first + j, x), z)); },
[=](const std::string& text) { return RText(text, Best(w, k + text.size(), z)); },
[=]() { return RLine(p.first, Best(w, p.first, z)); },
[=](const Doc& x, const Doc& y) {
return Better(
w,
k,
Best(w, k, List<best_arg>(best_arg(p.first, x), z)),
Best(w, k, List<best_arg>(best_arg(p.first, y), z))); }));
}));
}
/*! \brief Joins a vector of documents with a given separator document
* \example Join(["a", "b, "c"], ", ") => "a, b, c"
* \param vec the vector of documents
* \param sep the separator between documents
*/
inline Doc Join(const std::vector<Doc>& vec, const Doc& sep) {
// https://www.safaribooksonline.com/library/view/c-cookbook/0596007612/ch04s09.html
Doc output = Nil();
for (auto p = vec.begin(); p != vec.end(); ++p) {
output = output + *p;
if (p != vec.end() - 1) {
output = output + sep;
}
}
return output;
}
/*! \brief Creates an indented block.
* \param indent the indentation size
* \param open the opening string
* \param body the body of the block
* \param close the closing string
*/
inline Doc Block(size_t indent, const std::string& open,
const Doc& body, const std::string& close) {
return DocOfStr(open) + Nest(indent, Endl() + body) + Endl() + DocOfStr(close);
}
/*! \brief Creates a comma-separated sequence with opening and closing strings.
* \param open the opening string
* \param body the body of the Block
* \param close the closing string
*/
inline Doc Seq(const std::string& open,
const std::vector<Doc>& body, const std::string& close) {
return Group(DocOfStr(open) +
Nest(open.size(), Join(body, DocOfStr(",") + Endl())) +
DocOfStr(close));
}
//! \brief Either a space or a new line
inline Doc Sep() {
return Choose(DocOfStr(" "), Endl());
}
/*! \brief Layout a document to a given width
* \param d the document to render
* \param width the line width
*/
inline RDoc Layout(const Doc& d, size_t width) {
return Best(width, 0, List<best_arg>(best_arg(0, d), List<best_arg>()));
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_DOC_H_
...@@ -100,6 +100,8 @@ void EnvironmentNode::Merge(const Environment &env) { ...@@ -100,6 +100,8 @@ void EnvironmentNode::Merge(const Environment &env) {
} }
} }
TVM_REGISTER_NODE_TYPE(EnvironmentNode);
TVM_REGISTER_API("relay._make.Environment") TVM_REGISTER_API("relay._make.Environment")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]); *ret = EnvironmentNode::make(args[0]);
......
...@@ -17,6 +17,8 @@ Constant ConstantNode::make(runtime::NDArray data) { ...@@ -17,6 +17,8 @@ Constant ConstantNode::make(runtime::NDArray data) {
return Constant(n); return Constant(n);
} }
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant") TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ConstantNode::make(args[0]); *ret = ConstantNode::make(args[0]);
...@@ -44,6 +46,8 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) { ...@@ -44,6 +46,8 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
return Tuple(n); return Tuple(n);
} }
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple") TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleNode::make(args[0]); *ret = TupleNode::make(args[0]);
...@@ -61,6 +65,8 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { ...@@ -61,6 +65,8 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
return Var(n); return Var(n);
} }
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var") TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0], args[1]); *ret = VarNode::make(args[0], args[1]);
...@@ -82,6 +88,8 @@ GlobalVar GlobalVarNode::make(std::string name_hint) { ...@@ -82,6 +88,8 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
return GlobalVar(n); return GlobalVar(n);
} }
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar") TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = GlobalVarNode::make(args[0]); *ret = GlobalVarNode::make(args[0]);
...@@ -94,13 +102,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -94,13 +102,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Function FunctionNode::make(tvm::Array<Var> params, Function FunctionNode::make(tvm::Array<Var> params,
Type ret_type,
Expr body, Expr body,
Type ret_type,
tvm::Array<TypeParam> type_params) { tvm::Array<TypeParam> type_params) {
NodePtr<FunctionNode> n = make_node<FunctionNode>(); NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params); n->params = std::move(params);
n->ret_type = std::move(ret_type);
n->body = std::move(body); n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params); n->type_params = std::move(type_params);
return Function(n); return Function(n);
} }
...@@ -113,6 +121,8 @@ FuncType FunctionNode::func_type_annotation() const { ...@@ -113,6 +121,8 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {}); return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
} }
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function") TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]); *ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
...@@ -135,6 +145,8 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs, ...@@ -135,6 +145,8 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
return Call(n); return Call(n);
} }
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call") TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]); *ret = CallNode::make(args[0], args[1], args[2], args[3]);
...@@ -154,6 +166,8 @@ Let LetNode::make(Var var, Expr value, Expr body) { ...@@ -154,6 +166,8 @@ Let LetNode::make(Var var, Expr value, Expr body) {
return Let(n); return Let(n);
} }
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let") TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LetNode::make(args[0], args[1], args[2]); *ret = LetNode::make(args[0], args[1], args[2]);
...@@ -173,6 +187,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { ...@@ -173,6 +187,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
return If(n); return If(n);
} }
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) { TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IfNode::make(args[0], args[1], args[2]); *ret = IfNode::make(args[0], args[1], args[2]);
}); });
...@@ -190,6 +206,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { ...@@ -190,6 +206,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
return TupleGetItem(n); return TupleGetItem(n);
} }
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleGetItemNode::make(args[0], args[1]); *ret = TupleGetItemNode::make(args[0], args[1]);
}); });
......
...@@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { ...@@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
body.same_as(op->body)) { body.same_as(op->body)) {
return GetRef<Expr>(op); return GetRef<Expr>(op);
} else { } else {
return FunctionNode::make(params, ret_type, body, ty_params); return FunctionNode::make(params, body, ret_type, ty_params);
} }
} }
......
/*!
* Copyright (c) 2018 by Contributors
* \file text_printer.cc
* \brief Text printer to print relay in text form.
*/
#include <tvm/relay/environment.h>
#include <tvm/relay/expr_functor.h>
#include <sstream>
#include "../pass/type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
namespace relay {
/*!
* \brief the text value used in text printer.
* Defined as a struct for future compatibility reason
*/
struct TextValue {
/*! \brief The str representation */
std::string name;
// constructor
TextValue() {}
// constructor
explicit TextValue(std::string name) : name(name) {}
};
// operator overloading
inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NOLINT(*)
return os << val.name;
}
/*!
* \brief Meta data context for TextPrinter.
*
* This is an important part to enable bi-directional serializability.
* We use tvm's Node system to build the current IR.
* It can be hard to design a text format for all the possible nodes
* as the set of nodes can grow when we do more extensions.
*
* Instead of trying to design readable text format for every nodes,
* we support a meta-data section in the text format.
* We allow the text format to refer to a node in the meta-data section.
*
* The meta-data section is a json serialized string of an Array<NodeRef>.
* Each element in the meta-data section can be referenced by the text format.
* Each meta data node is printed in the following format.
*
* meta.<type-key-of-node>(<index-in-meta-section>)
*
* Specifically, consider the following IR(constructed by python).
*
* \code
*
* n = tvm.var("n")
* x = tvm.relay.var("x", shape=(n, 1))
* f = tvm.relay.Function([x], x)
* print(f.astext())
*
* \endcode
*
* The corresponding text format is shown in the following code block.
*
* \code
*
* function(%x: Tensor[(meta.Variable(id=0),), float32]) {
* %x
* }
* # Meta data section is a json-serialized string
* # of the following array.
* # [tvm.var("n")]
*
* \endcode
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta-data seciton,
* we print it as meta.Variable(0).
*
* The text parser can recover this object by loading from the corresponding
* location in the meta data section.
*
* This is is a design trade-off.
* It allows us to embedded any meta-data in the text format,
* while still being able to tweak the text part of the printed IR easily.
*/
class TextMetaDataContext {
public:
/*!
* \brief Get text representation of meta node.
* \param node The node to be converted to meta node.
* \return A string representation of the meta node.
*/
std::string GetMetaNode(const NodeRef& node) {
std::ostringstream os;
auto it = meta_index_.find(node);
int64_t index;
if (it != meta_index_.end()) {
index = it->second;
} else {
index = static_cast<int64_t>(meta_data_.size());
meta_data_.push_back(node);
meta_index_[node] = index;
}
os << "meta." << node->type_key() << "(id=" << index << ")";
return os.str();
}
/*!
* \brief Get the metadata section in json format.
* \return the meta datastring.
*/
std::string GetMetaSection() const {
if (meta_data_.size() == 0) return std::string();
return SaveJSON(Array<NodeRef>(meta_data_));
}
private:
/*! \brief additional metadata stored in TVM json format */
std::vector<NodeRef> meta_data_;
/*! \brief map from meta data into its index */
std::unordered_map<NodeRef, int64_t, NodeHash, NodeEqual> meta_index_;
};
class TextPrinter :
public ExprFunctor<TextValue(const Expr&)> ,
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public:
/*!
* \brief Print a node to string.
* \param node.
* \return The string representation.
*/
std::string Print(const NodeRef& node) {
if (node.as<FunctionNode>()) {
this->PrintFunc(Downcast<Function>(node));
} else if (node.as<EnvironmentNode>()) {
this->PrintEnv(Downcast<Environment>(node));
} else if (node.as_derived<TypeNode>()) {
this->PrintType(Downcast<Type>(node), stream_);
} else if (node.as_derived<ExprNode>()) {
this->PrintExpr(Downcast<Expr>(node));
} else {
stream_ << node;
}
std::string meta_json = meta_.GetMetaSection();
if (meta_json.length() != 0) {
// append meta data in the end.
stream_ << "# meta data\n"
<< "r\"\"\"\n"
<< meta_json << "\n"
<< "\"\"\"";
}
return stream_.str();
}
void PrintFunc(const Function& func) {
this->PrintFuncInternal("function", func);
stream_ << "\n";
}
void PrintEnv(const Environment& env) {
int counter = 0;
for (const auto& kv : env->functions) {
std::ostringstream os;
if (counter++ != 0) {
stream_ << "\n";
}
os << "def @" << kv.first->name_hint;
this->PrintFuncInternal(os.str(), kv.second);
stream_ << "\n";
}
}
void PrintExpr(const Expr& expr) {
TextValue val = GetValue(expr);
stream_ << val << "\n";
}
/*!
* \brief Get text representation of expr.
*
* This function may generate additional instructions
* in order to compute the final result id of expr.
*
* When trying to recursively print out an Expr.
* The caller should always call GetValue of its children first.
* Then the caller can print out to stream_ using the obtained value.
*
* This is to avoid the call of subsequent GetValue print out
* additional instructions which get mixed with the partial instruction
* printed by the caller.
*
* \param expr The input expression.
* \return The text value of Expr.
*/
TextValue GetValue(const Expr& expr) {
auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;
TextValue val = this->VisitExpr(expr);
memo_[expr] = val;
return val;
}
//------------------------------------
// Overload of Expr printing functions
//------------------------------------
TextValue VisitExpr_(const ConstantNode* op) final {
// Print out simple scalar directly.
if (op->is_scalar()) {
std::ostringstream os;
DataType dtype = TVMType2Type(op->data->dtype);
CHECK_EQ(op->data->ctx.device_type, kDLCPU);
if (dtype == Int(32)) {
return ConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
} else if (dtype == Int(64)) {
return ConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
} else if (dtype == Float(32)) {
return ConstScalar(dtype, static_cast<const float*>(op->data->data));
} else if (dtype == Float(64)) {
return ConstScalar(dtype, static_cast<const double*>(op->data->data));
}
}
// default fall-back, record it as meta node.
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op));
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const TupleNode* op) final {
std::vector<TextValue> fields;
for (Expr field : op->fields) {
fields.push_back(GetValue(field));
}
// NOTE: always recursively visit to get ids,
// before print out the current line
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = (";
for (size_t i = 0; i < fields.size(); ++i) {
stream_ << fields[i];
if (i + 1 != fields.size()) {
stream_ << ", ";
}
}
stream_ << ')';
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
// This is an unbounded var.
TextValue val = AllocVarName(var);
this->PrintIndent();
stream_ << "free_var ";
this->PrintVarDecl(var, stream_);
this->PrintEndInst("\n");
return val;
}
TextValue VisitExpr_(const GlobalVarNode* op) final {
return TextValue('@' + op->name_hint);
}
TextValue VisitExpr_(const FunctionNode* op) final {
TextValue id = AllocTempVar();
std::ostringstream os;
os << id << " = function";
this->PrintFuncInternal(os.str(), GetRef<Function>(op));
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const CallNode* op) final {
// TODO(tqchen, M.K.): support generic call
// possibly through meta-data
CHECK_EQ(op->type_args.size(), 0U)
<< "generic call not yet supported";
TextValue call_op = GetValue(op->op);
std::vector<TextValue> args;
for (Expr arg : op->args) {
args.emplace_back(GetValue(arg));
}
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << call_op << "(";
for (size_t i = 0; i < args.size(); ++i) {
stream_ << args[i];
if (i + 1 != args.size()) {
stream_ << ", ";
}
}
this->PrintCallAttrs(op->op, op->attrs, stream_);
stream_ << ")";
this->PrintEndInst("");
this->PrintOptionalInfo(GetRef<Expr>(op));
stream_ << '\n';
return id;
}
TextValue VisitExpr_(const LetNode* op) final {
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = ";
this->PrintScope(GetRef<Expr>(op));
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const IfNode* op) final {
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = ";
this->PrintScope(GetRef<Expr>(op));
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const OpNode* op) final {
return TextValue(op->name);
}
TextValue VisitExpr_(const TupleGetItemNode* op) final {
TextValue tuple = GetValue(op->tuple);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << tuple << "[" << op->index << "]";
this->PrintEndInst("\n");
return id;
}
/*!
* \brief Print the type to os
* \param type The type to be printed.
* \param os The output type.
*/
void PrintType(const Type& type, std::ostream& os) { // NOLINT(*)
this->VisitType(type, os);
}
//------------------------------------
// Overload of Expr printing functions
//------------------------------------
void VisitType_(const TensorTypeNode* node, std::ostream& os) final { // NOLINT(*)
// scalar type
if (node->shape.size() == 0) {
os << runtime::TVMType2String(Type2TVMType(node->dtype));
return;
}
os << "Tensor[(";
for (size_t i = 0; i < node->shape.size(); ++i) {
this->PrintAttr(node->shape[i], os);
if (i + 1 != node->shape.size()) {
os << ", ";
}
}
// conform to python tuple format (1,)
if (node->shape.size() == 1) {
os << ",";
}
os << "), " << runtime::TVMType2String(Type2TVMType(node->dtype)) << "]";
}
void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*)
// by default always print as meta-data
os << meta_.GetMetaNode(GetRef<NodeRef>(node));
}
/*!
* \brief Print an attribute value to os.
* \param value The value to be printed.
* \param os The output type.
*/
void PrintAttr(const NodeRef& value, std::ostream& os) { // NOLINT(*)
this->VisitAttr(value, os);
}
//------------------------------------
// Overload of Attr printing functions
//------------------------------------
void VisitAttr_(const ArrayNode* op, std::ostream& os) final { // NOLINT(*)
os << "[";
for (size_t i = 0; i < op->data.size(); ++i) {
this->PrintAttr(NodeRef(op->data[i]), os);
if (i + 1 != op->data.size()) {
os << ", ";
}
}
os << "]";
}
void VisitAttrDefault_(const Node* op, std::ostream& os) final { // NOLINT(*)
os << meta_.GetMetaNode(GetRef<NodeRef>(op));
}
void VisitAttr_(const ir::IntImm* op, std::ostream& os) final { // NOLINT(*)
this->PrintConstScalar(op->type, &(op->value), os);
}
void VisitAttr_(const ir::UIntImm* op, std::ostream& os) final { // NOLINT(*)
this->PrintConstScalar(op->type, &(op->value), os);
}
void VisitAttr_(const ir::FloatImm* op, std::ostream& os) final { // NOLINT(*)
this->PrintConstScalar(op->type, &(op->value), os);
}
void VisitAttr_(const ir::StringImm* op, std::ostream& os) final { // NOLINT(*)
this->PrintString(op->value, os);
}
protected:
/*!
* \brief Print attributes after call.
* \param op The operator to be called.
* \param attrs The attributes.
* \param os The output stream.
*/
void PrintCallAttrs(const Expr& op, const Attrs& attrs, std::ostream& os); // NOLINT(*)
/*!
* \brief Print the a new scopr.
* \param body The body.
*/
void PrintScope(Expr body) {
stream_ << "{\n";
int sid = this->BeginScope();
this->PrintScopeBody(body);
this->EndScope(sid);
this->PrintIndent();
stream_ << "}";
}
/*!
* \brief Print the body of a new scope without {}
*
* This function will keep printing continuous sequence
* of let/if scope without introducing a new scope in the text.
*
* \param body The body.
*/
void PrintScopeBody(Expr body) {
if (const LetNode* let = body.as<LetNode>()) {
TextValue value = GetValue(let->value);
AllocVarName(let->var);
// let var = value;
this->PrintIndent();
stream_ << "let ";
this->PrintVarDecl(let->var, stream_);
stream_ << " = " << value;
this->PrintEndInst("\n");
this->PrintScopeBody(let->body);
} else if (const IfNode* ifnode = body.as<IfNode>()) {
TextValue cond = GetValue(ifnode->cond);
this->PrintIndent();
stream_ << "if (" << cond << ") ";
this->PrintScope(ifnode->true_branch);
this->PrintIndent();
stream_ << "else ";
this->PrintScope(ifnode->false_branch);
this->PrintEndInst("\n");
} else {
TextValue value = GetValue(body);
this->PrintIndent();
stream_ << value;
this->PrintEndInst("\n");
}
}
/*!
* \brief Internal function to print a function argument list and its body.
* \param prefix The prefix before argument list.
* \param fn The function to be printed.
*/
void PrintFuncInternal(std::string prefix, const Function& fn) {
// TODO(tqchen, M.K.) support generic function
// Possibly through meta-data
CHECK_EQ(fn->type_params.size(), 0U)
<< "generic fn not yet supported";
this->PrintIndent();
stream_ << prefix << "(";
size_t decl_indent = prefix.length() + 1;
for (size_t i = 0; i < fn->params.size(); ++i) {
if (i != 0) {
this->PrintIndent(decl_indent);
}
AllocVarName(fn->params[i]);
this->PrintVarDecl(fn->params[i], stream_);
if (i + 1 != fn->params.size()) {
stream_ << ",\n";
}
}
stream_ << ") ";
if (fn->ret_type.defined()) {
stream_ << " -> ";
this->PrintType(fn->ret_type, stream_);
}
this->PrintScope(fn->body);
}
/*!
* \brief Print additional info about expr in comment.
* \param expr The expression.
*/
void PrintOptionalInfo(const Expr& expr) {
// additional information in comment.
if (expr->checked_type_.defined()) {
stream_ << " # ty=";
this->PrintType(expr->checked_type(), stream_);
}
}
/*!
* \brief print var_name[:type]
* \param var The variable to be printed
* \param os The output stream
*/
void PrintVarDecl(const Var& var, std::ostream& os) { // NOLINT(*)
TextValue v = GetValue(var);
os << v;
if (var->type_annotation.defined()) {
os << ": ";
this->PrintType(var->type_annotation, os);
}
}
/*!
* \brief Get a constant scalar value.
* \param dtype The data type.
* \param data The pointer to the data.
* \tparam T the content data type holding the data.
*/
template<typename T>
TextValue ConstScalar(DataType dtype, const T* data) {
std::ostringstream os;
PrintConstScalar(dtype, data, os);
return TextValue(os.str());
}
/*!
* \brief special method to print out const scalar
* \param dtype The data type
* \param data The pointer to hold the data.
* \param os The output stream.
*/
template<typename T>
void PrintConstScalar(DataType dtype, const T* data, std::ostream& os) { // NOLINT(*)
if (dtype == Int(32)) {
os << data[0];
} else if (dtype == Float(32)) {
os << data[0] << 'f';
} else if (dtype == Bool()) {
PrintBool(data[0] != 0, os);
} else {
os << dtype << "(" << data[0] << ")";
}
}
/*!
* \brief Print constant bool value.
* \param value The value to be printed.
* \param os The output stream
*/
void PrintBool(bool value, std::ostream& os) { // NOLINT(*)
if (value) {
os << "True";
} else {
os << "False";
}
}
/*!
* \brief Print constant string.
* \param value The value to be printed.
* \param os The output stream
*/
void PrintString(const std::string& value, std::ostream& os) { // NOLINT(*)
// TODO(M.K.): add escape.
os << "\"" << value << "\"";
}
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix) {
auto it = name_alloc_map_.find(prefix);
if (it != name_alloc_map_.end()) {
while (true) {
std::ostringstream os;
os << prefix << (++it->second);
std::string name = os.str();
if (name_alloc_map_.count(name) == 0) {
prefix = name;
break;
}
}
}
name_alloc_map_[prefix] = 0;
return prefix;
}
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
*/
int BeginScope() {
int sid = static_cast<int>(scope_valid_.size());
scope_valid_.push_back(true);
indent_ += 2;
return sid;
}
/*!
* \brief mark the end of an old scope.
* \param scope_id The scope id to be ended.
*/
void EndScope(int scope_id) {
scope_valid_[scope_id] = false;
indent_ -= 2;
}
/*!
* \brief Print the indent to the stream.
* \param more_indent More indentation besides the current one.
*/
void PrintIndent(int64_t more_indent = 0) {
for (int i = 0; i < indent_ + more_indent; ++i) {
stream_ << ' ';
}
}
/*!
* \brief print end of the line.
*/
void PrintEndInst(const char* suffix) {
stream_ << suffix;
}
/*!
* \brief Allocate temporary value
* \return A new text value.
*/
TextValue AllocTempVar() {
std::ostringstream os;
os << '%' << temp_var_counter_++;
return TextValue(os.str());
}
/*!
* \brief Allocate name to a variable.
* \param var The input variable.
* \return The corresponding name.
*/
TextValue AllocVarName(const Var& var) {
std::string name = GetUniqueName('%' + var->name_hint);
TextValue val(name);
CHECK(!memo_.count(var));
memo_[var] = val;
return val;
}
private:
class AttrPrinter;
friend class AttrPrinter;
/*! \brief meta data context */
TextMetaDataContext meta_;
/*! \brief Check whether scope is still valid */
std::vector<bool> scope_valid_;
/*! \brief The current indentation value */
int indent_{0};
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief Map from expression to its text value */
std::unordered_map<Expr, TextValue, NodeHash, NodeEqual> memo_;
/*! \brief counter of temporary variable */
int64_t temp_var_counter_{0};
/*! \brief Output stream */
std::ostringstream stream_;
};
/*!
* \brief Attribute printer which prints the attributes in the call.
*/
class TextPrinter::AttrPrinter: public AttrVisitor {
public:
AttrPrinter(std::ostream& stream, TextPrinter* parent) // NOLINT(*)
: stream_(stream), parent_(parent) {}
void Visit(const char* key, double* value) final {
PrintSep();
stream_ << key << "=" << value[0];
}
void Visit(const char* key, int64_t* value) final {
PrintSep();
stream_ << key << "=" << value[0];
}
void Visit(const char* key, uint64_t* value) final {
PrintSep();
stream_ << key << "=" << value[0];
}
void Visit(const char* key, int* value) final {
PrintSep();
stream_ << key << "=" << value[0];
}
void Visit(const char* key, bool* value) final {
PrintSep();
stream_ << key << "=";
parent_->PrintBool(value[0], stream_);
}
void Visit(const char* key, std::string* value) final {
PrintSep();
stream_ << key << "=";
parent_->PrintString(value[0], stream_);
}
void Visit(const char* key, void** value) final {
LOG(FATAL) << "do not allow void as argument";
}
void Visit(const char* key, DataType* value) final {
PrintSep();
stream_ << key << "=";
parent_->PrintString(runtime::TVMType2String(Type2TVMType(value[0])), stream_);
}
void Visit(const char* key, NodeRef* value) final {
PrintSep();
stream_ << key << "=";
parent_->PrintAttr(value[0], stream_);
}
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
}
private:
void PrintSep() {
stream_ << ", ";
}
std::ostream& stream_; // NOLINT(*)
TextPrinter* parent_;
};
void TextPrinter::PrintCallAttrs(const Expr& op,
const Attrs& attrs,
std::ostream& os) { // NOLINT(*)
if (!attrs.defined()) return;
if (const auto* op_node = op.as<OpNode>()) {
if (attrs->type_index() == op_node->attrs_type_index) {
AttrPrinter printer(os, this);
const_cast<BaseAttrsNode*>(attrs.operator->())
->VisitNonDefaultAttrs(&printer);
return;
}
}
os << ", " << meta_.GetMetaNode(attrs);
}
std::string RelayPrint(const NodeRef& node) {
return TextPrinter().Print(node);
}
TVM_REGISTER_API("relay._expr._text_print")
.set_body_typed<std::string(const NodeRef&)>(RelayPrint);
} // namespace relay
} // namespace tvm
...@@ -22,6 +22,8 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { ...@@ -22,6 +22,8 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype); return TensorTypeNode::make({}, dtype);
} }
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType") TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Array<IndexExpr> shape = args[0]; Array<IndexExpr> shape = args[0];
...@@ -30,8 +32,8 @@ TVM_REGISTER_API("relay._make.TensorType") ...@@ -30,8 +32,8 @@ TVM_REGISTER_API("relay._make.TensorType")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode *node, .set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
tvm::IRPrinter *p) { tvm::IRPrinter *p) {
p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape << ")"; p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
}); });
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
...@@ -41,6 +43,8 @@ TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) { ...@@ -41,6 +43,8 @@ TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
return TypeParam(n); return TypeParam(n);
} }
TVM_REGISTER_NODE_TYPE(TypeParamNode);
TVM_REGISTER_API("relay._make.TypeParam") TVM_REGISTER_API("relay._make.TypeParam")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
int kind = args[1]; int kind = args[1];
...@@ -61,6 +65,8 @@ IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) { ...@@ -61,6 +65,8 @@ IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
return IncompleteType(n); return IncompleteType(n);
} }
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType") TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0]; int kind = args[0];
...@@ -86,6 +92,8 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, ...@@ -86,6 +92,8 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
return FuncType(n); return FuncType(n);
} }
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType") TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]); *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
...@@ -111,6 +119,8 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, ...@@ -111,6 +119,8 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
return TypeRelation(n); return TypeRelation(n);
} }
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation") TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]); *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
...@@ -129,6 +139,8 @@ TupleType TupleTypeNode::make(Array<Type> fields) { ...@@ -129,6 +139,8 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
return TupleType(n); return TupleType(n);
} }
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType") TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleTypeNode::make(args[0]); *ret = TupleTypeNode::make(args[0]);
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
TVM_REGISTER_NODE_TYPE(ConvAttrs); TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
bool Conv2DRel(const Array<Type>& types, bool Conv2DRel(const Array<Type>& types,
int num_inputs, int num_inputs,
...@@ -25,7 +25,7 @@ bool Conv2DRel(const Array<Type>& types, ...@@ -25,7 +25,7 @@ bool Conv2DRel(const Array<Type>& types,
static const Layout kNCHW("NCHW"); static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW"); static const Layout kOIHW("OIHW");
const ConvAttrs* param = attrs.as<ConvAttrs>(); const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
CHECK(param != nullptr); CHECK(param != nullptr);
const Layout in_layout(param->data_layout); const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout); const Layout kernel_layout(param->weight_layout);
...@@ -113,7 +113,7 @@ Expr MakeConv2D(Expr data, ...@@ -113,7 +113,7 @@ Expr MakeConv2D(Expr data,
std::string weight_layout, std::string weight_layout,
std::string out_layout, std::string out_layout,
DataType out_dtype) { DataType out_dtype) {
auto attrs = make_node<ConvAttrs>(); auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides); attrs->strides = std::move(strides);
attrs->padding = std::move(padding); attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation); attrs->dilation = std::move(dilation);
...@@ -148,6 +148,7 @@ with the layer input to produce a tensor of outputs. ...@@ -148,6 +148,7 @@ with the layer input to produce a tensor of outputs.
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`. (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DAttrs")
.set_num_inputs(2) .set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.") .add_argument("weight", "Tensor", "The weight tensor.")
...@@ -296,6 +297,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW` ...@@ -296,6 +297,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1] out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DTransposeAttrs")
.set_num_inputs(2) .set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.") .add_argument("weight", "Tensor", "The weight tensor.")
......
...@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.dense") ...@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.dense")
- **out**: `(x1, x2, ..., xn, units)`. - **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DenseAttrs")
.set_num_inputs(2) .set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.") .add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.") .add_argument("weight", "2D Tensor", "Weight matrix.")
...@@ -107,6 +108,7 @@ RELAY_REGISTER_OP("nn.leaky_relu") ...@@ -107,6 +108,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
`y = x > 0 ? x : alpha * x` `y = x > 0 ? x : alpha * x`
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LeakyReluAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.") .add_argument("data", "Tensor", "Input data.")
.set_support_level(3) .set_support_level(3)
...@@ -135,6 +137,7 @@ RELAY_REGISTER_OP("nn.softmax") ...@@ -135,6 +137,7 @@ RELAY_REGISTER_OP("nn.softmax")
- **data**: The input data - **data**: The input data
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(1)
...@@ -163,6 +166,7 @@ RELAY_REGISTER_OP("nn.log_softmax") ...@@ -163,6 +166,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
- **data**: The input data - **data**: The input data
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(1)
...@@ -171,9 +175,9 @@ RELAY_REGISTER_OP("nn.log_softmax") ...@@ -171,9 +175,9 @@ RELAY_REGISTER_OP("nn.log_softmax")
// BatchFlatten // BatchFlatten
bool BatchFlattenRel(const Array<Type>& types, bool BatchFlattenRel(const Array<Type>& types,
int num_inputs, int num_inputs,
const Attrs& attrs, const Attrs& attrs,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2); CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>(); const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false; if (data == nullptr) return false;
...@@ -278,6 +282,7 @@ centered at that value (zero padding is added where necessary). ...@@ -278,6 +282,7 @@ centered at that value (zero padding is added where necessary).
- **data**: The input tensor. - **data**: The input tensor.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LRNAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2) .set_support_level(2)
...@@ -296,12 +301,12 @@ Expr MakeL2Normalize(Expr data, ...@@ -296,12 +301,12 @@ Expr MakeL2Normalize(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.l2_normalize") TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv); runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
}); });
RELAY_REGISTER_OP("nn.l2_normalize") RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer. .describe(R"code(L2 Normalization layer.
Normalizes along dimension axis using an L2 norm Normalizes along dimension axis using an L2 norm
...@@ -352,6 +357,7 @@ During training, each element of the input is set to zero with probability ``p`` ...@@ -352,6 +357,7 @@ During training, each element of the input is set to zero with probability ``p``
The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged. The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DropoutAttrs")
.set_num_inputs(1) .set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.") .add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1) .set_support_level(1)
...@@ -478,6 +484,7 @@ axis to be the last item in the input shape. ...@@ -478,6 +484,7 @@ axis to be the last item in the input shape.
.. note:: .. note::
This operator can be optimized away for inference. This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BatchNormAttrs")
.set_num_inputs(5) .set_num_inputs(5)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.") .add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.") .add_argument("gamma", "Tensor", "The gamma scale factor.")
......
...@@ -60,7 +60,7 @@ bool PadRel(const Array<Type>& types, ...@@ -60,7 +60,7 @@ bool PadRel(const Array<Type>& types,
} }
// Handler to create a call to the padding op used by front-end FFI // Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) { Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
auto attrs = make_node<PadAttrs>(); auto attrs = make_node<PadAttrs>();
attrs->pad_value = pad_value; attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width); attrs->pad_width = std::move(pad_width);
......
...@@ -76,6 +76,7 @@ RELAY_REGISTER_OP("expand_dims") ...@@ -76,6 +76,7 @@ RELAY_REGISTER_OP("expand_dims")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ExpandDimsAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1) .set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel); .add_type_rel("ExpandDims", ExpandDimsRel);
...@@ -481,6 +482,7 @@ RELAY_REGISTER_OP("zeros") ...@@ -481,6 +482,7 @@ RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros. .describe(R"code(Fill array with zeros.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0) .set_num_inputs(0)
.set_support_level(3) .set_support_level(3)
.add_type_rel("InitOp", InitOpRel); .add_type_rel("InitOp", InitOpRel);
...@@ -503,6 +505,7 @@ RELAY_REGISTER_OP("ones") ...@@ -503,6 +505,7 @@ RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones. .describe(R"code(Fill array with ones.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0) .set_num_inputs(0)
.set_support_level(3) .set_support_level(3)
.add_type_rel("InitOp", InitOpRel); .add_type_rel("InitOp", InitOpRel);
...@@ -697,6 +700,7 @@ RELAY_REGISTER_OP("squeeze") ...@@ -697,6 +700,7 @@ RELAY_REGISTER_OP("squeeze")
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
.set_num_inputs(1) .set_num_inputs(1)
.set_attrs_type_key("relay.attrs.SqueezeAttrs")
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3) .set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel); .add_type_rel("Squeeze", SqueezeRel);
......
...@@ -74,7 +74,10 @@ class CalcDep : private ExprMutator { ...@@ -74,7 +74,10 @@ class CalcDep : private ExprMutator {
} }
Expr VisitExpr_(const FunctionNode* f) final { Expr VisitExpr_(const FunctionNode* f) final {
return FunctionNode::make(f->params, f->ret_type, Eliminate(f->body), f->type_params); return FunctionNode::make(f->params,
Eliminate(f->body),
f->ret_type,
f->type_params);
} }
// generate the let list from dependency graph // generate the let list from dependency graph
......
...@@ -20,6 +20,7 @@ class TypeFunctor; ...@@ -20,6 +20,7 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \ #define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); } { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \ #define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \ vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \ [](const NodeRef& n, TSelf* self, Args... args) { \
......
import tvm
from tvm import relay
from tvm.relay.expr import debug_print
from tvm.relay.ir_builder import IRBuilder
ib = IRBuilder()
def show(e):
r = debug_print(ib.env, e)
assert r is not None
def test_constant():
arr = tvm.nd.array(10)
const = relay.Constant(arr)
show(const)
# should print the array inside?
def test_tuple():
fields = tvm.convert([])
tup = relay.Tuple(fields)
show(tup)
def test_local_var():
name_hint = 's'
lv = relay.Var(name_hint)
show(lv)
def test_dup_var():
lv = relay.Var('s')
rv = relay.Var('s')
show(relay.Tuple([lv, rv]))
def test_large_dup_var():
av = relay.Var('s')
bv = relay.Var('s')
cv = relay.Var('s')
show(relay.Tuple([av, bv, cv]))
def test_global_var():
name_hint = 'g'
gv = relay.GlobalVar(name_hint)
gv.name_hint == name_hint
show(gv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = params[0]
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
show(fn)
def test_call():
op = relay.Var('f')
arg_names = ['a', 'b', 'c', 'd']
args = tvm.convert([relay.Var(n) for n in arg_names])
call = relay.Call(op, args, None, None)
show(call)
def test_let():
ty = relay.ty.TensorType((10, 20), 'float32')
lv = relay.Var('x', ty)
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv)
show(let)
def test_if():
cond = relay.Var('cond')
left = relay.Var('left')
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)
def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
show(g)
import tvm
import numpy as np
from tvm import relay
do_print = [False]
def show(text):
if do_print[0]:
print("---------------------------")
print(text)
def test_func():
x = relay.var("x", shape=(3, 2))
y = relay.var("y")
one = relay.const(10e10, dtype="float32")
z = relay.add(x, one)
z = relay.add(z, z)
f = relay.Function([x, y], z)
show(z.astext())
show(f.astext())
def test_env():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
z = relay.add(x, y)
z = relay.add(z, z)
f = relay.Function([x, y], z)
env = relay.Environment()
env.add("myf", f)
text = env.astext()
assert "def @myf" in text
assert "%1 = add(%0, %0) # ty=float32" in text
show(text)
def test_meta_data():
n, c, h, w = tvm.var("n"), 10, 224, 224
x = relay.var("x", shape=(n, c, h, w))
w = relay.var("w")
z = relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=2)
f = relay.Function([x, w], z)
text = f.astext()
assert "channels=2" in text
assert "meta.Variable(id=0)" in text
show(text)
text = relay.const([1,2,3]).astext()
assert "meta.relay.Constant(id=0)" in text
show(text)
def test_call_attrs():
x = relay.var("x")
# non default args
z = relay.nn.softmax(x, axis=2)
assert "axis=2" in z.astext()
# default args
z = relay.nn.softmax(x)
assert "softmax(%x)" in z.astext()
# non default args
z = relay.expand_dims(x, axis=2, num_newaxis=2)
assert "num_newaxis=2" in z.astext()
def test_let_if_scope():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
cond = relay.var("cond", "bool")
v1 = relay.var("v")
v2 = relay.var("v", "float32")
then_branch = relay.Let(
v1, relay.const(1, "float32"),
relay.Let(v2, x, relay.subtract(v1, v2)))
v3 = relay.var("v")
let2 = relay.Let(v3, y, v3)
else_branch = relay.add(let2, let2)
result = relay.If(cond, then_branch, else_branch)
f = relay.Function([x, y, cond], result)
text = f.astext()
assert text.count("{") == 4
assert "%cond: bool" in text
show(f.astext())
if __name__ == "__main__":
do_print[0] = True
test_let_if_scope()
test_func()
test_env()
test_meta_data()
test_call_attrs()
...@@ -10,7 +10,7 @@ def test_well_formed(): ...@@ -10,7 +10,7 @@ def test_well_formed():
let = relay.Let(x, v, x) let = relay.Let(x, v, x)
assert well_formed(let) assert well_formed(let)
assert not well_formed(relay.Let(x, v, let)) assert not well_formed(relay.Let(x, v, let))
f = relay.Function([x], ty, x) f = relay.Function([x], x, ty)
assert well_formed(f) assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing) # this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other. # but we want all binder to be distinct from each other.
......
...@@ -262,49 +262,49 @@ def test_function_alpha_equal(): ...@@ -262,49 +262,49 @@ def test_function_alpha_equal():
basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)] basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
basic_tps = [tp1, tp2] basic_tps = [tp1, tp2]
func = relay.Function([v1, v2], func = relay.Function([v1, v2], v1,
tt2, v1, basic_tps) tt2, basic_tps)
mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps) mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
assert alpha_equal(func, mapped) assert alpha_equal(func, mapped)
fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps) fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, fewer_params) assert not alpha_equal(func, fewer_params)
more_params = relay.Function([relay.Var("v3", tt1), more_params = relay.Function([relay.Var("v3", tt1),
relay.Var("v4", tt2), relay.Var("v4", tt2),
relay.Var("v2", tt2)], tt2, v4, basic_tps) relay.Var("v2", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, more_params) assert not alpha_equal(func, more_params)
params_unordered = relay.Function([v2, v1], params_unordered = relay.Function([v2, v1], v1,
tt2, v1, basic_tps) tt2, basic_tps)
assert not alpha_equal(func, params_unordered) assert not alpha_equal(func, params_unordered)
params_mismatch = relay.Function([v1, v3], params_mismatch = relay.Function([v1, v3], v1,
tt2, v1, basic_tps) tt2, basic_tps)
assert not alpha_equal(func, params_mismatch) assert not alpha_equal(func, params_mismatch)
# also would not typecheck # also would not typecheck
ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps) ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps)
assert not alpha_equal(func, ret_type_mismatch) assert not alpha_equal(func, ret_type_mismatch)
# also mis-typed # also mis-typed
different_body = relay.Function(basic_args, tt2, v3, basic_tps) different_body = relay.Function(basic_args, v3, tt2, basic_tps)
assert not alpha_equal(func, different_body) assert not alpha_equal(func, different_body)
fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1]) fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1])
assert not alpha_equal(func, fewer_type_params) assert not alpha_equal(func, fewer_type_params)
more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3]) more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3])
assert not alpha_equal(func, more_type_params) assert not alpha_equal(func, more_type_params)
type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1]) type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
assert not alpha_equal(func, type_params_unordered) assert not alpha_equal(func, type_params_unordered)
different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4]) different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4])
assert not alpha_equal(func, different_type_params) assert not alpha_equal(func, different_type_params)
# a well-typed example that also differs in body, ret type, and type params # a well-typed example that also differs in body, ret type, and type params
tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4])) tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
assert not alpha_equal(func, tupled_example) assert not alpha_equal(func, tupled_example)
......
...@@ -59,7 +59,7 @@ def test_recursion(): ...@@ -59,7 +59,7 @@ def test_recursion():
n = relay.Var("n", e.int32) n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32) data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
value = relay.Function([n, data], e.float32, funcbody, []) value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0))) orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
assert alpha_equal(dead_code_elimination(orig), orig) assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three) assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
......
...@@ -13,7 +13,7 @@ def test_free_vars(): ...@@ -13,7 +13,7 @@ def test_free_vars():
let = relay.Let(x, v, x) let = relay.Let(x, v, x)
fvx = free_vars(let) fvx = free_vars(let)
assert len(free_vars(let)) == 0 assert len(free_vars(let)) == 0
f = relay.Function([x], ty, x) f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0 assert len(free_vars(f)) == 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment