Unverified Commit 9d44279e by Tianqi Chen Committed by GitHub

[RELAY][Refactor] TextPrinter, move ret_type after body in Function. (#1918)

parent 4c13ee22
......@@ -131,6 +131,13 @@ class BaseAttrsNode : public Node {
*/
inline void PrintDocString(std::ostream &os) const; // NOLINT(*)
/*!
* \brief Visit attributes that do not equal the default value.
*
* \note This is useful to extract fields for concise printing.
* \param v The visitor
*/
TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
/*!
* \brief Get the field information
* \return The fields in the Attrs.
*/
......@@ -199,6 +206,7 @@ class DictAttrsNode : public BaseAttrsNode {
TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other) const final;
......@@ -300,15 +308,15 @@ struct AttrNopEntry {
return *this;
}
template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
return *this;
}
template<typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
return *this;
}
template<typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
return *this;
}
};
......@@ -603,7 +611,7 @@ class AttrDocEntry {
return *this;
}
template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
std::ostringstream os;
os << info_->type_info << ", default=" << value;
info_->type_info = os.str();
......@@ -649,6 +657,57 @@ class AttrExistVisitor {
return AttrNopEntry();
}
};
template<typename T>
struct AttrTriggerNonDefaultEntry {
using TSelf = AttrTriggerNonDefaultEntry<T>;
// constructor
AttrTriggerNonDefaultEntry(
AttrVisitor* visitor, const char* key, T* data)
: visitor_(visitor), key_(key), data_(data) {}
~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
if (trigger_) {
visitor_->Visit(key_, data_);
}
}
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
return *this;
}
TSelf& set_default(const T& value) {
if (AttrsEqual()(value, *data_)) {
trigger_ = false;
}
return *this;
}
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
return *this;
}
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
return *this;
}
private:
AttrVisitor* visitor_;
const char * key_;
T *data_;
bool trigger_{true};
};
class AttrNonDefaultVisitor {
public:
explicit AttrNonDefaultVisitor(AttrVisitor* visitor)
: visitor_(visitor) {
}
template<typename T>
AttrTriggerNonDefaultEntry<T>
operator()(const char* key, T* value) {
return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
}
private:
AttrVisitor* visitor_;
};
} // namespace detail
/*!
......@@ -665,6 +724,11 @@ class AttrsNode : public BaseAttrsNode {
self()->__VisitAttrs__(vis);
}
void VisitNonDefaultAttrs(AttrVisitor* v) final {
detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
CHECK_EQ(args.size() % 2, 0);
const int kLinearSearchBound = 16;
......
......@@ -13,7 +13,7 @@ namespace tvm {
namespace relay {
/*! \brief Attributes used in convolution operators */
struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
......@@ -25,7 +25,7 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(ConvAttrs, "relay.attrs.ConvAttrs") {
TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
......@@ -55,14 +55,14 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("__undef__")
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0))
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
......@@ -123,7 +123,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0))
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
......
......@@ -78,7 +78,7 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
.describe("Target shape.");
TVM_ATTR_FIELD(dtype)
.describe("Target data type.")
.set_default(Int(0));
.set_default(NullValue<DataType>());
}
}; // struct InitOpAttrs
......
......@@ -181,8 +181,6 @@ class FunctionNode : public ExprNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief
* The expression which represents the computation of the function,
......@@ -190,6 +188,8 @@ class FunctionNode : public ExprNode {
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
......@@ -201,8 +201,8 @@ class FunctionNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
v->Visit("ret_type", &ret_type);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -217,8 +217,8 @@ class FunctionNode : public ExprNode {
TVM_DLL FuncType func_type_annotation() const;
TVM_DLL static Function make(tvm::Array<Var> params,
Type ret_type,
Expr body,
Type ret_type,
tvm::Array<TypeParam> ty_params);
static constexpr const char* _type_key = "relay.Function";
......
......@@ -48,6 +48,11 @@ class OpNode : public relay::ExprNode {
*/
std::string attrs_type_key;
/*!
* \brief attribute type index,
* this field varies in each run and is not exposed to frontend.
*/
uint32_t attrs_type_index{0};
/*!
* \brief number of input arguments to the operator,
* -1 means it is variable length
*/
......@@ -416,6 +421,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*)
inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*)
const std::string& type_key) {
get()->attrs_type_key = type_key;
get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str());
return *this;
}
......
......@@ -14,13 +14,15 @@ from . import libinfo
#----------------------------
if sys.version_info[0] == 3:
string_types = (str,)
numeric_types = (float, int, np.float32, np.int32)
integer_types = (int, np.int32)
numeric_types = integer_types + (float, np.float32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32)
integer_types = (int, long, np.int32)
numeric_types = integer_types + (float, np.float32)
py_str = lambda x: x
......
# pylint: disable=wildcard-import, redefined-builtin
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from . import base
from . import ty
......@@ -19,6 +19,9 @@ from . import image
# Span
Span = base.Span
# Env
Environment = env.Environment
# Type
Type = ty.Type
TupleType = ty.TupleType
......@@ -40,3 +43,7 @@ Call = expr.Call
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
# helper functions
var = expr.var
const = expr.const
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
from . import _expr
NodeBase = NodeBase
......@@ -20,7 +21,19 @@ def register_relay_node(type_key=None):
return _register_tvm_node(type_key)
class RelayNode(NodeBase):
def astext(self):
"""Get the text format of the expression.
Returns
-------
text : str
The text format of the expression.
"""
return _expr._text_print(self)
@register_relay_node
class Span(NodeBase):
class Span(RelayNode):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, NodeBase
from .base import register_relay_node, RelayNode
from . import _make
from . import _env
@register_relay_node
class Environment(NodeBase):
class Environment(RelayNode):
"""The global Relay environment containing functions,
options and more.
"""
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
from __future__ import absolute_import
from .base import NodeBase, register_relay_node
from . import _expr
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import ty as _ty
from .._ffi import base as _base, node as _node
from .. import nd as _nd
from .. import convert
class Expr(NodeBase):
class Expr(RelayNode):
"""The base type for all Relay expressions."""
@property
def checked_type(self):
......@@ -56,7 +60,7 @@ class Tuple(Expr):
@register_relay_node
class Var(Expr):
"""A local variable in Tvm.Relay.
"""A local variable in Relay.
Local variable can be used to declare input
arguments to a function, or intermediate variables.
......@@ -101,26 +105,26 @@ class Function(Expr):
params: List[tvm.relay.Var]
List of input parameters to the function.
ret_type: tvm.relay.Type
The return type annotation of the function.
body: tvm.relay.Expr
The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def __init__(self,
params,
ret_type,
body,
ret_type=None,
type_params=None):
if type_params is None:
type_params = convert([])
self.__init_handle_by_constructor__(
_make.Function, params, ret_type, body, type_params)
_make.Function, params, body, ret_type, type_params)
@register_relay_node
......@@ -158,7 +162,7 @@ class Let(Expr):
Parameters
----------
var: tvm.relay.Var
variable: tvm.relay.Var
The local variable to be bound.
value: tvm.relay.Expr
......@@ -167,9 +171,9 @@ class Let(Expr):
body: tvm.relay.Expr
The body of the let binding.
"""
def __init__(self, var, value, body):
def __init__(self, variable, value, body):
self.__init_handle_by_constructor__(
_make.Let, var, value, body)
_make.Let, variable, value, body)
@register_relay_node
......@@ -208,4 +212,105 @@ class TupleGetItem(Expr):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index)
debug_print = _expr._debug_print
class TupleWrapper(_node.NodeGeneric):
"""TupleWrapper.
This class is a Python wrapper for a Relay tuple of known size.
It allows for accessing the fields of the Relay tuple as though
it were a Python tuple.
Parameters
----------
tuple_value: tvm.relay.Expr
The input tuple
size: int
The size of the tuple.
"""
def __init__(self, tuple_value, size):
self.tuple_value = tuple_value
self.size = size
def asnode(self):
"""Returns the underlying Relay tuple if this wrapper is passed
as an argument to an FFI function."""
return self.tuple_value
def __getitem__(self, key):
return self.tuple_value.fields[key]
def __len__(self):
return len(self.tuple_value.fields)
def var(name_hint,
type_annotation=None,
shape=None,
dtype="float32"):
"""Create a new tvm.relay.Var.
This is a simple wrapper function that allows specify
shape and dtype directly.
Parameters
----------
name_hint: str
The name of the variable.
This name only acts as a hint, and is not used
for equality.
type_annotation: Optional[tvm.relay.Type, str]
The type annotation on the variable.
When type_annotation is a str, we will create a scalar variable.
shape: Optional[List[tvm.Expr]]
The shape of the tensor type.
dtype: str, optional
The data type of the tensor.
Examples
--------
.. code-block:: python
# The following 4 lines are equivalent to each other
x = tvm.relay.Var("x", tvm.relay.TensorType([1, 2]))
x = tvm.relay.var("x", tvm.relay.TensorType([1, 2]))
x = tvm.relay.var("x", shape=[1, 2])
x = tvm.relay.var("x", shape=[1, 2], dtype="float32")
# The following 2 lines are equivalent to each other.
y = tvm.relay.var("x", "float32")
y = tvm.relay.var("x", shape=(), dtype="float32")
"""
if type_annotation is not None and shape is not None:
raise ValueError("Can only specify either type_annotation or shape.")
if shape is not None:
type_annotation = _ty.TensorType(shape, dtype)
elif isinstance(type_annotation, str):
type_annotation = _ty.TensorType((), type_annotation)
return Var(name_hint, type_annotation)
def const(value, dtype=None):
"""Create a constant value.
Parameters
----------
value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
The constant value.
dtype: str, optional
The data type of the value.
"""
if isinstance(value, _base.numeric_types):
value = _np.array(value, dtype=dtype)
elif isinstance(value, (bool, list)):
value = _np.array(value, dtype=dtype)
if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)
if not isinstance(value, _nd.NDArray):
raise ValueError("value has to be scalar or NDArray")
return Constant(value)
......@@ -11,32 +11,6 @@ from .expr import Expr, Constant, Let, Var, Function, If
from .env import Environment
class TupleWrapper(tvm._ffi.node.NodeGeneric):
"""TupleWrapper.
This class is a Python wrapper for a Relay tuple of known size.
It allows for accessing the fields of the Relay tuple as though
it were a Python tuple.
"""
def __init__(self, tuple_value, size):
self.tuple_value = tuple_value
self.size = size
def asnode(self):
"""Returns the underlying Relay tuple if this wrapper is passed
as an argument to an FFI function."""
return self.tuple_value
def __getitem__(self, key):
return self.tuple_value.fields[key]
def __len__(self):
return len(self.tuple_value.fields)
def _convert_to_value(arg, ctxt=tvm.cpu(0)):
# type: (Any, tvm.Context) -> tvm.nd.NDArray
"""Convert Python values into the appropriate types
......@@ -132,8 +106,8 @@ class PartialFunc(object):
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
return Function(
self.params,
self.ret_type,
self.body,
self.ret_type,
self.type_params)
#pylint: disable=invalid-name
......@@ -325,7 +299,7 @@ class IRBuilder(object):
def _on_exit():
bindings, _, _, ret_value = self.exit_scope()
exp = _mk_let(bindings, ret_value)
self.env.add(name, Function(params, ret_type, exp))
self.env.add(name, Function(params, exp, ret_type))
return WithScope(10, _on_exit)
......
"""Neural network operations."""
from __future__ import absolute_import as _abs
from tvm.relay.ir_builder import TupleWrapper
from ...expr import TupleWrapper
from . import _make
......@@ -145,7 +145,7 @@ def conv2d_transpose(data,
weight_layout, output_padding, out_dtype)
def softmax(data, axis):
def softmax(data, axis=1):
r"""Computes softmax.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
......@@ -158,7 +158,7 @@ def softmax(data, axis):
data: relay.Expr
The input data to the operator.
axis: int
axis: int, optional
The axis to sum over when computing softmax
Returns
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import NodeBase, register_relay_node
from .base import RelayNode, register_relay_node
from . import _make
class Type(NodeBase):
class Type(RelayNode):
"""The base type for all Relay types."""
def __eq__(self, other):
......@@ -21,27 +21,25 @@ class Type(NodeBase):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay, see tvm/relay/type.h for more details.
"""A concrete TensorType in Relay.
This is the type assigned to tensor's with a known dype and shape. For
example a tensor of `float32` and `(5, 5)`.
"""
def __init__(self, shape, dtype):
"""Construct a tensor type.
Parameters
----------
shape: list of tvm.Expr
dtype: str
Parameters
----------
shape: List[tvm.Expr]
The shape of the Tensor
Returns
-------
tensor_type: The TensorType
"""
self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)
dtype: str, optional
The content data type.
"""
def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(
_make.TensorType, shape, dtype)
class Kind(IntEnum):
......
......@@ -17,11 +17,15 @@ namespace tvm {
template <typename FType>
class AttrFunctor;
#define ATTR_FUNCTOR_DEFAULT \
{ return VisitAttrDefault_(op, std::forward<Args>(args)...); }
#define ATTR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->Visit_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
return self->VisitAttr_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \
// A functor for common attribute information.
......@@ -40,21 +44,21 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
* \param args Additional arguments.
* \return The result of the call
*/
virtual R Visit(const NodeRef& n, Args... args) {
virtual R VisitAttr(const NodeRef& n, Args... args) {
static FType vtable = InitVTable();
if (vtable.can_dispatch(n)) {
return vtable(n, this, std::forward<Args>(args)...);
} else {
return VisitDefault_(n, std::forward<Args>(args)...);
return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
}
}
virtual R Visit_(const ArrayNode* op, Args... args) = 0;
virtual R Visit_(const StrMapNode* op, Args... args) = 0;
virtual R Visit_(const ir::IntImm* op, Args... args) = 0;
virtual R Visit_(const ir::UIntImm* op, Args... args) = 0;
virtual R Visit_(const ir::FloatImm* op, Args... args) = 0;
virtual R Visit_(const ir::StringImm* op, Args... args) = 0;
virtual R VisitDefault_(const NodeRef& n, Args... args) = 0;
virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
private:
// initialize the vtable.
......
......@@ -11,6 +11,10 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
}
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
}
void DictAttrsNode::InitByPackedArgs(
const runtime::TVMArgs& args, bool allow_unknown) {
for (int i = 0; i < args.size(); i += 2) {
......@@ -55,48 +59,48 @@ class AttrsEqualChecker :
if (!equal_) return false;
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!this->Visit(lhs, rhs)) {
if (!this->VisitAttr(lhs, rhs)) {
equal_ = false;
}
return equal_;
}
bool VisitDefault_(const NodeRef& lhs, const NodeRef& other) final {
bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final {
if (lhs->derived_from<BaseAttrsNode>()) {
return static_cast<const BaseAttrsNode*>(lhs.get())->ContentEqual(other.get());
return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get());
}
return lhs.same_as(other);
return lhs == other.get();
}
bool Visit_(const IntImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<IntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const UIntImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<UIntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const FloatImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<FloatImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const StringImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StringImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const ArrayNode* lhs, const NodeRef& other) final {
bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
......@@ -106,7 +110,7 @@ class AttrsEqualChecker :
return true;
}
bool Visit_(const StrMapNode* lhs, const NodeRef& other) final {
bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
......@@ -127,38 +131,38 @@ class AttrContentHasher :
public:
size_t result_{0};
void VisitDefault_(const NodeRef& value) final {
void VisitAttrDefault_(const Node* value) final {
if (value->derived_from<BaseAttrsNode>()) {
Update(static_cast<const BaseAttrsNode*>(value.get())->ContentHash());
Update(static_cast<const BaseAttrsNode*>(value)->ContentHash());
} else {
Update(NodeHash()(value));
Update(NodeHash()(GetRef<NodeRef>(value)));
}
}
void Visit_(const IntImm* op) final {
void VisitAttr_(const IntImm* op) final {
Update(std::hash<int64_t>()(op->value));
}
void Visit_(const UIntImm* op) final {
void VisitAttr_(const UIntImm* op) final {
Update(std::hash<uint64_t>()(op->value));
}
void Visit_(const FloatImm* op) final {
void VisitAttr_(const FloatImm* op) final {
Update(std::hash<double>()(op->value));
}
void Visit_(const StringImm* op) final {
void VisitAttr_(const StringImm* op) final {
Update(std::hash<std::string>()(op->value));
}
void Visit_(const ArrayNode* op) final {
void VisitAttr_(const ArrayNode* op) final {
Update(op->data.size());
for (size_t i = 0; i < op->data.size(); ++i) {
this->Visit(NodeRef(op->data[i]));
this->VisitAttr(NodeRef(op->data[i]));
}
}
void Visit_(const StrMapNode* lhs) final {
void VisitAttr_(const StrMapNode* lhs) final {
using Entry = std::pair<std::string, NodePtr<Node> >;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
......@@ -166,7 +170,7 @@ class AttrContentHasher :
});
for (const Entry& kv : data) {
Update(std::hash<std::string>()(kv.first));
this->Visit(NodeRef(kv.second));
this->VisitAttr(NodeRef(kv.second));
}
}
......@@ -184,7 +188,7 @@ bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) {
size_t AttrsHash::Hash(const NodeRef& node) {
if (!node.defined()) return 0;
AttrContentHasher hasher;
hasher.Visit(node);
hasher.VisitAttr(node);
return hasher.result_;
}
......
......@@ -208,6 +208,8 @@ class JSONAttrGetter : public AttrVisitor {
node_->type_key = node->type_key();
// sepcially handle global object
auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key);
CHECK(f != nullptr)
<< "Node type \'" << node_->type_key << "\' is not registered in TVM";
if (f->fglobal_key != nullptr) {
node_->global_key = f->fglobal_key(node);
return;
......
......@@ -51,6 +51,8 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
return Span(n);
}
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SpanNode::make(args[0], args[1], args[2]);
......
/*!
* Copyright (c) 2018 by Contributors
* \file debug_printer.cc
* \brief A pretty printer for the Relay IR.
* As we had not determined a formal syntax yet, right now it is only for debug purpose.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/error.h>
#include <iostream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include <string>
#include <vector>
#include <iostream>
#include "../pass/type_functor.h"
#include "doc.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
Doc KindDocify(TypeParamNode::Kind k) {
switch (k) {
case TypeParamNode::kShapeVar:
return DocOfStr("ShapeVar");
case TypeParamNode::kShape:
return DocOfStr("Shape");
case TypeParamNode::kBaseType:
return DocOfStr("BaseType");
case TypeParamNode::kType:
return DocOfStr("Type");
default:
LOG(FATAL) << "unreachable code: case not handle in kind";
throw; // log fatal throw but compiler doesnt know
}
}
template<typename T>
std::vector<Doc> MapDocify(const tvm::Array<T>& arr, const std::function<Doc(const T&)>& f) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(f(arr[i]));
}
return vec;
}
template<typename T, typename Hash = std::hash<T>, typename Eq = std::equal_to<T>>
class Counter {
std::unordered_map<T, size_t, Hash, Eq> cnt_;
public:
Counter() = default;
Counter(const Counter&) = delete;
size_t operator()(const T& t) {
auto v = cnt_.count(t) == 0 ? 0 : cnt_.at(t) + 1;
cnt_[t] = v;
return v;
}
};
std::string Mangle(const std::string& str, size_t s) {
return str + "_" + std::to_string(s);
// return s == 0 ? str : str + "_" + std::to_string(s - 1);
// the above line look prettier but is dangerous:
// suppose we have x, x, x_0. mangling will give x, x_0, x_0!
// the save approach give x_0, x_1, x_0_1, and in fact never clash:
// stripping _([0-9]*) is invert of mangle under all circumstances.
// another problem is we need to prevent Var/TypeParam/GlobalVar clashing each other.
}
constexpr size_t indent = 2;
struct TypeParamName {
bool operator==(const TypeParamName&) const {
return true;
}
};
struct mhash {
size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept {
return 0;
}
};
class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
Environment env;
Counter<TypeParamName, mhash> cnt;
std::unordered_map<TypeParam, Doc, NodeHash, NodeEqual> map;
std::vector<Doc> DocifyTypeArray(const tvm::Array<Type>& arr) {
return MapDocify<Type>(arr, [=](const Type& t) { return Docify(t); });
}
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) {
return Docify(tp);
});
}
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
return MapDocify<TypeConstraint>(arr, [=](const TypeConstraint& tc) { return Docify(tc); });
}
Doc VisitType_(const TensorTypeNode* t) final {
return DocOfStr("tensor");
}
Doc VisitType_(const TypeParamNode* p) final {
auto tp = GetRef<TypeParam>(p);
if (map.count(tp) == 0) {
auto name =
DocOfStr(Mangle("tp", cnt(TypeParamName())) +
std::string(":")) +
KindDocify(p->kind);
map.insert(std::pair<TypeParam, Doc>(tp, name));
}
return map.at(tp);
}
Doc Quantify(const tvm::Array<TypeParam>& tp, const Doc& d) {
if (tp.size() == 0) {
return d;
}
return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d;
}
Doc Constraint(const tvm::Array<TypeConstraint>& tc, const Doc& d) {
if (tc.size() == 0) {
return d;
}
return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d;
}
Doc VisitType_(const FuncTypeNode* f) final {
auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() +
DocOfStr("->") + Sep() + Docify(f->ret_type);
return Group(Quantify(f->type_params,
Constraint(f->type_constraints, inner)));
}
Doc VisitType_(const TypeRelationNode* r) final {
return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")");
}
Doc VisitType_(const TupleTypeNode* t) final {
return Seq("<", DocifyTypeArray(t->fields), ">");
}
Doc VisitType_(const IncompleteTypeNode* i) final {
return DocOfStr("_");
}
public:
TypeDocifier(const Environment& env) : env(env) { }
Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); }
};
class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
Environment env;
Counter<std::string> cnt;
std::unordered_map<Var, std::string, NodeHash, NodeEqual> map;
TypeDocifier td;
std::string VarName(const Var& v) {
if (map.count(v) == 0) {
map.insert(std::pair<Var, std::string>(v, Mangle(v->name_hint, cnt(v->name_hint))));
}
return map.at(v);
}
Doc TypeAnnotation(const Doc& d, const Type& t) {
// test for t being null. probably shouldnt has null. should talk to jared.
if (!t.get() || t.as<IncompleteTypeNode>()) {
return d;
} else {
return d + DocOfStr(":") + td.Docify(t);
}
}
std::vector<Doc> DocifyExprArray(const tvm::Array<Expr>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
}
return vec;
}
std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
std::vector<Doc> vec;
for (Var param : arr) {
vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
param->type_annotation));
}
return vec;
}
Doc VisitExpr_(const ConstantNode* c) final {
return DocOfStr("some_constant");
}
Doc VisitExpr_(const TupleNode* t) final {
return Seq("<", DocifyExprArray(t->fields), ">");
}
Doc VisitExpr_(const VarNode* v) final {
return DocOfStr(VarName(GetRef<Var>(v)));
}
Doc VisitExpr_(const GlobalVarNode* g) final {
return DocOfStr(g->name_hint);
}
Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() +
Block(indent, "{", Docify(f->body), "}"));
}
Doc VisitExpr_(const CallNode* c) final {
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}
Doc VisitExpr_(const LetNode* l) final {
return Group(DocOfStr("let") + Sep() +
TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body));
}
Doc VisitExpr_(const IfNode* i) final {
return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() +
Block(indent, "{", Docify(i->true_branch), "}") + Sep() +
DocOfStr("else") + Sep() +
Block(indent, "{", Docify(i->false_branch), "}"));
}
Doc VisitExpr_(const OpNode* o) final {
return DocOfStr(o->name);
}
Doc VisitExpr_(const TupleGetItemNode* g) final {
return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
}
public:
ExprDocifier(const Environment& env) : env(env), td(env) { }
Doc Docify(const Expr& e) { return (*this)(e); }
};
Doc DocOfExpr(const Environment& env, const Expr& expr) {
ExprDocifier d(env);
return d.Docify(expr);
}
Doc DocOfType(const Environment& env, const Type& expr) {
TypeDocifier d(env);
return d.Docify(expr);
}
RDoc ExprRDoc(const Environment& env, const Expr& expr) {
return Layout(DocOfExpr(env, expr));
}
RDoc TypeRDoc(const Environment& env, const Type& expr) {
return Layout(DocOfType(env, expr));
}
std::ostream & DebugPrint(const Environment& env, const Expr& e, std::ostream& os) {
return os << ExprRDoc(env, e);
}
std::ostream & DebugPrint(const Environment& env, const Type& t, std::ostream& os) {
return os << TypeRDoc(env, t);
}
std::string PrintExpr(const Environment& env, const Expr& e) {
std::stringstream ss;
ss << ExprRDoc(env, e);
return ss.str();
}
std::string PrintType(const Environment& env, const Type& t) {
std::stringstream ss;
ss << TypeRDoc(env, t);
return ss.str();
}
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
*ret = PrintExpr(args[0], Downcast<Expr>(x));
}
});
} // namespace relay
} // namespace tvm
......@@ -100,6 +100,8 @@ void EnvironmentNode::Merge(const Environment &env) {
}
}
TVM_REGISTER_NODE_TYPE(EnvironmentNode);
TVM_REGISTER_API("relay._make.Environment")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]);
......
......@@ -17,6 +17,8 @@ Constant ConstantNode::make(runtime::NDArray data) {
return Constant(n);
}
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ConstantNode::make(args[0]);
......@@ -44,6 +46,8 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
return Tuple(n);
}
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleNode::make(args[0]);
......@@ -61,6 +65,8 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
return Var(n);
}
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0], args[1]);
......@@ -82,6 +88,8 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
return GlobalVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = GlobalVarNode::make(args[0]);
......@@ -94,13 +102,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Function FunctionNode::make(tvm::Array<Var> params,
Type ret_type,
Expr body,
Type ret_type,
tvm::Array<TypeParam> type_params) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params);
n->ret_type = std::move(ret_type);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
return Function(n);
}
......@@ -113,6 +121,8 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
......@@ -135,6 +145,8 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
return Call(n);
}
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]);
......@@ -154,6 +166,8 @@ Let LetNode::make(Var var, Expr value, Expr body) {
return Let(n);
}
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LetNode::make(args[0], args[1], args[2]);
......@@ -173,6 +187,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
return If(n);
}
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IfNode::make(args[0], args[1], args[2]);
});
......@@ -190,6 +206,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
return TupleGetItem(n);
}
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleGetItemNode::make(args[0], args[1]);
});
......
......@@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return FunctionNode::make(params, ret_type, body, ty_params);
return FunctionNode::make(params, body, ret_type, ty_params);
}
}
......
......@@ -22,6 +22,8 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype);
}
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Array<IndexExpr> shape = args[0];
......@@ -30,8 +32,8 @@ TVM_REGISTER_API("relay._make.TensorType")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
tvm::IRPrinter *p) {
p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape << ")";
tvm::IRPrinter *p) {
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
......@@ -41,6 +43,8 @@ TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
return TypeParam(n);
}
TVM_REGISTER_NODE_TYPE(TypeParamNode);
TVM_REGISTER_API("relay._make.TypeParam")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int kind = args[1];
......@@ -61,6 +65,8 @@ IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
return IncompleteType(n);
}
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
......@@ -86,6 +92,8 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
return FuncType(n);
}
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
......@@ -111,6 +119,8 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
return TypeRelation(n);
}
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
......@@ -129,6 +139,8 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
return TupleType(n);
}
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleTypeNode::make(args[0]);
......
......@@ -11,7 +11,7 @@
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ConvAttrs);
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
bool Conv2DRel(const Array<Type>& types,
int num_inputs,
......@@ -25,7 +25,7 @@ bool Conv2DRel(const Array<Type>& types,
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const ConvAttrs* param = attrs.as<ConvAttrs>();
const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout);
......@@ -113,7 +113,7 @@ Expr MakeConv2D(Expr data,
std::string weight_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<ConvAttrs>();
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
......@@ -148,6 +148,7 @@ with the layer input to produce a tensor of outputs.
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
......@@ -296,6 +297,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DTransposeAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
......
......@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.dense")
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DenseAttrs")
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
......@@ -107,6 +108,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
`y = x > 0 ? x : alpha * x`
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LeakyReluAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(3)
......@@ -135,6 +137,7 @@ RELAY_REGISTER_OP("nn.softmax")
- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
......@@ -163,6 +166,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
......@@ -171,9 +175,9 @@ RELAY_REGISTER_OP("nn.log_softmax")
// BatchFlatten
bool BatchFlattenRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
......@@ -278,6 +282,7 @@ centered at that value (zero padding is added where necessary).
- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LRNAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......@@ -296,12 +301,12 @@ Expr MakeL2Normalize(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
});
RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer.
.describe(R"code(L2 Normalization layer.
Normalizes along dimension axis using an L2 norm
......@@ -352,6 +357,7 @@ During training, each element of the input is set to zero with probability ``p``
The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DropoutAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1)
......@@ -478,6 +484,7 @@ axis to be the last item in the input shape.
.. note::
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BatchNormAttrs")
.set_num_inputs(5)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
......
......@@ -60,7 +60,7 @@ bool PadRel(const Array<Type>& types,
}
// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
auto attrs = make_node<PadAttrs>();
attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width);
......
......@@ -76,6 +76,7 @@ RELAY_REGISTER_OP("expand_dims")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ExpandDimsAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel);
......@@ -481,6 +482,7 @@ RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
......@@ -503,6 +505,7 @@ RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
......@@ -697,6 +700,7 @@ RELAY_REGISTER_OP("squeeze")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.SqueezeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel);
......
......@@ -74,7 +74,10 @@ class CalcDep : private ExprMutator {
}
Expr VisitExpr_(const FunctionNode* f) final {
return FunctionNode::make(f->params, f->ret_type, Eliminate(f->body), f->type_params);
return FunctionNode::make(f->params,
Eliminate(f->body),
f->ret_type,
f->type_params);
}
// generate the let list from dependency graph
......
......@@ -20,6 +20,7 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
......
import tvm
from tvm import relay
from tvm.relay.expr import debug_print
from tvm.relay.ir_builder import IRBuilder
ib = IRBuilder()
def show(e):
r = debug_print(ib.env, e)
assert r is not None
def test_constant():
arr = tvm.nd.array(10)
const = relay.Constant(arr)
show(const)
# should print the array inside?
def test_tuple():
fields = tvm.convert([])
tup = relay.Tuple(fields)
show(tup)
def test_local_var():
name_hint = 's'
lv = relay.Var(name_hint)
show(lv)
def test_dup_var():
lv = relay.Var('s')
rv = relay.Var('s')
show(relay.Tuple([lv, rv]))
def test_large_dup_var():
av = relay.Var('s')
bv = relay.Var('s')
cv = relay.Var('s')
show(relay.Tuple([av, bv, cv]))
def test_global_var():
name_hint = 'g'
gv = relay.GlobalVar(name_hint)
gv.name_hint == name_hint
show(gv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
ret_type = None
body = params[0]
type_params = tvm.convert([])
fn = relay.Function(params, ret_type, body, type_params)
show(fn)
def test_call():
op = relay.Var('f')
arg_names = ['a', 'b', 'c', 'd']
args = tvm.convert([relay.Var(n) for n in arg_names])
call = relay.Call(op, args, None, None)
show(call)
def test_let():
ty = relay.ty.TensorType((10, 20), 'float32')
lv = relay.Var('x', ty)
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv)
show(let)
def test_if():
cond = relay.Var('cond')
left = relay.Var('left')
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)
def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
show(g)
import tvm
import numpy as np
from tvm import relay
do_print = [False]
def show(text):
if do_print[0]:
print("---------------------------")
print(text)
def test_func():
x = relay.var("x", shape=(3, 2))
y = relay.var("y")
one = relay.const(10e10, dtype="float32")
z = relay.add(x, one)
z = relay.add(z, z)
f = relay.Function([x, y], z)
show(z.astext())
show(f.astext())
def test_env():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
z = relay.add(x, y)
z = relay.add(z, z)
f = relay.Function([x, y], z)
env = relay.Environment()
env.add("myf", f)
text = env.astext()
assert "def @myf" in text
assert "%1 = add(%0, %0) # ty=float32" in text
show(text)
def test_meta_data():
n, c, h, w = tvm.var("n"), 10, 224, 224
x = relay.var("x", shape=(n, c, h, w))
w = relay.var("w")
z = relay.nn.conv2d(x, w,
kernel_size=(3, 3),
padding=(1, 1),
channels=2)
f = relay.Function([x, w], z)
text = f.astext()
assert "channels=2" in text
assert "meta.Variable(id=0)" in text
show(text)
text = relay.const([1,2,3]).astext()
assert "meta.relay.Constant(id=0)" in text
show(text)
def test_call_attrs():
x = relay.var("x")
# non default args
z = relay.nn.softmax(x, axis=2)
assert "axis=2" in z.astext()
# default args
z = relay.nn.softmax(x)
assert "softmax(%x)" in z.astext()
# non default args
z = relay.expand_dims(x, axis=2, num_newaxis=2)
assert "num_newaxis=2" in z.astext()
def test_let_if_scope():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
cond = relay.var("cond", "bool")
v1 = relay.var("v")
v2 = relay.var("v", "float32")
then_branch = relay.Let(
v1, relay.const(1, "float32"),
relay.Let(v2, x, relay.subtract(v1, v2)))
v3 = relay.var("v")
let2 = relay.Let(v3, y, v3)
else_branch = relay.add(let2, let2)
result = relay.If(cond, then_branch, else_branch)
f = relay.Function([x, y, cond], result)
text = f.astext()
assert text.count("{") == 4
assert "%cond: bool" in text
show(f.astext())
if __name__ == "__main__":
do_print[0] = True
test_let_if_scope()
test_func()
test_env()
test_meta_data()
test_call_attrs()
......@@ -10,7 +10,7 @@ def test_well_formed():
let = relay.Let(x, v, x)
assert well_formed(let)
assert not well_formed(relay.Let(x, v, let))
f = relay.Function([x], ty, x)
f = relay.Function([x], x, ty)
assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
......
......@@ -262,49 +262,49 @@ def test_function_alpha_equal():
basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
basic_tps = [tp1, tp2]
func = relay.Function([v1, v2],
tt2, v1, basic_tps)
mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps)
func = relay.Function([v1, v2], v1,
tt2, basic_tps)
mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
assert alpha_equal(func, mapped)
fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps)
fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, fewer_params)
more_params = relay.Function([relay.Var("v3", tt1),
relay.Var("v4", tt2),
relay.Var("v2", tt2)], tt2, v4, basic_tps)
relay.Var("v2", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, more_params)
params_unordered = relay.Function([v2, v1],
tt2, v1, basic_tps)
params_unordered = relay.Function([v2, v1], v1,
tt2, basic_tps)
assert not alpha_equal(func, params_unordered)
params_mismatch = relay.Function([v1, v3],
tt2, v1, basic_tps)
params_mismatch = relay.Function([v1, v3], v1,
tt2, basic_tps)
assert not alpha_equal(func, params_mismatch)
# also would not typecheck
ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps)
ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps)
assert not alpha_equal(func, ret_type_mismatch)
# also mis-typed
different_body = relay.Function(basic_args, tt2, v3, basic_tps)
different_body = relay.Function(basic_args, v3, tt2, basic_tps)
assert not alpha_equal(func, different_body)
fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1])
fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1])
assert not alpha_equal(func, fewer_type_params)
more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3])
more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3])
assert not alpha_equal(func, more_type_params)
type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1])
type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
assert not alpha_equal(func, type_params_unordered)
different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4])
different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4])
assert not alpha_equal(func, different_type_params)
# a well-typed example that also differs in body, ret type, and type params
tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4]))
tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
assert not alpha_equal(func, tupled_example)
......
......@@ -59,7 +59,7 @@ def test_recursion():
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
value = relay.Function([n, data], e.float32, funcbody, [])
value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
......
......@@ -13,7 +13,7 @@ def test_free_vars():
let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([x], ty, x)
f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment