Unverified Commit 9d44279e by Tianqi Chen Committed by GitHub

[RELAY][Refactor] TextPrinter, move ret_type after body in Function. (#1918)

parent 4c13ee22
......@@ -131,6 +131,13 @@ class BaseAttrsNode : public Node {
*/
inline void PrintDocString(std::ostream &os) const; // NOLINT(*)
/*!
* \brief Visit attributes that do not equal the default value.
*
* \note This is useful to extract fields for concise printing.
* \param v The visitor
*/
TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
/*!
* \brief Get the field information
* \return The fields in the Attrs.
*/
......@@ -199,6 +206,7 @@ class DictAttrsNode : public BaseAttrsNode {
TVM_DLL static Attrs make(Map<std::string, NodeRef> dict);
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;
bool ContentEqual(const Node* other) const final;
......@@ -300,15 +308,15 @@ struct AttrNopEntry {
return *this;
}
template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
return *this;
}
template<typename T>
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
return *this;
}
template<typename T>
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
return *this;
}
};
......@@ -603,7 +611,7 @@ class AttrDocEntry {
return *this;
}
template<typename T>
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED T value) {
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
std::ostringstream os;
os << info_->type_info << ", default=" << value;
info_->type_info = os.str();
......@@ -649,6 +657,57 @@ class AttrExistVisitor {
return AttrNopEntry();
}
};
/*!
 * \brief RAII setter entry that forwards a field to the wrapped visitor
 *        when it goes out of scope, unless set_default() observed that the
 *        field still holds its declared default value.
 *
 *        Used by AttrNonDefaultVisitor to implement "visit only the
 *        attributes that differ from their defaults".
 */
template<typename T>
struct AttrTriggerNonDefaultEntry {
  using TSelf = AttrTriggerNonDefaultEntry<T>;
  // Remember the visitor, the field key and the field storage;
  // the actual visit is deferred to the destructor.
  AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
      : visitor_(visitor), key_(key), data_(data) {}
  ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
    if (!suppressed_) {
      visitor_->Visit(key_, data_);
    }
  }
  // Description text is irrelevant for this visitor; ignore it.
  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) {
    return *this;
  }
  // Suppress the deferred visit when the current value equals the default.
  TSelf& set_default(const T& value) {
    if (AttrsEqual()(value, *data_)) {
      suppressed_ = true;
    }
    return *this;
  }
  // Bounds do not affect which attributes are non-default; ignore them.
  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
    return *this;
  }
  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
    return *this;
  }

 private:
  AttrVisitor* visitor_;
  const char* key_;
  T* data_;
  // Becomes true once set_default() sees the value equals its default.
  bool suppressed_{false};
};
/*!
 * \brief Adaptor used with __VisitAttrs__ that only forwards attributes
 *        whose value differs from their declared default to the
 *        underlying AttrVisitor.
 */
class AttrNonDefaultVisitor {
 public:
  explicit AttrNonDefaultVisitor(AttrVisitor* visitor)
      : visitor_(visitor) {}
  // Produce an RAII entry that visits (key, value) on destruction
  // unless set_default() marks the value as still-default.
  template<typename T>
  AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
    return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
  }

 private:
  AttrVisitor* visitor_;
};
} // namespace detail
/*!
......@@ -665,6 +724,11 @@ class AttrsNode : public BaseAttrsNode {
self()->__VisitAttrs__(vis);
}
void VisitNonDefaultAttrs(AttrVisitor* v) final {
detail::AttrNonDefaultVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
CHECK_EQ(args.size() % 2, 0);
const int kLinearSearchBound = 16;
......
......@@ -13,7 +13,7 @@ namespace tvm {
namespace relay {
/*! \brief Attributes used in convolution operators */
struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
......@@ -25,7 +25,7 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(ConvAttrs, "relay.attrs.ConvAttrs") {
TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
......@@ -55,14 +55,14 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
.describe("Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("__undef__")
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0))
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
......@@ -123,7 +123,7 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0))
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
......
......@@ -78,7 +78,7 @@ struct InitOpAttrs : public tvm::AttrsNode<InitOpAttrs> {
.describe("Target shape.");
TVM_ATTR_FIELD(dtype)
.describe("Target data type.")
.set_default(Int(0));
.set_default(NullValue<DataType>());
}
}; // struct InitOpAttrs
......
......@@ -181,8 +181,6 @@ class FunctionNode : public ExprNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief
* The expression which represents the computation of the function,
......@@ -190,6 +188,8 @@ class FunctionNode : public ExprNode {
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
......@@ -201,8 +201,8 @@ class FunctionNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("params", &params);
v->Visit("ret_type", &ret_type);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -217,8 +217,8 @@ class FunctionNode : public ExprNode {
TVM_DLL FuncType func_type_annotation() const;
TVM_DLL static Function make(tvm::Array<Var> params,
Type ret_type,
Expr body,
Type ret_type,
tvm::Array<TypeParam> ty_params);
static constexpr const char* _type_key = "relay.Function";
......
......@@ -48,6 +48,11 @@ class OpNode : public relay::ExprNode {
*/
std::string attrs_type_key;
/*!
* \brief attribute type index,
* this field varies in each run and is not exposed to frontend.
*/
uint32_t attrs_type_index{0};
/*!
* \brief number of input arguments to the operator,
* -1 means it is variable length
*/
......@@ -416,6 +421,7 @@ inline OpRegistry& OpRegistry::set_num_inputs(int32_t n) { // NOLINT(*)
inline OpRegistry& OpRegistry::set_attrs_type_key( // NOLINT(*)
const std::string& type_key) {
get()->attrs_type_key = type_key;
get()->attrs_type_index = Node::TypeKey2Index(type_key.c_str());
return *this;
}
......
......@@ -14,13 +14,15 @@ from . import libinfo
#----------------------------
if sys.version_info[0] == 3:
string_types = (str,)
numeric_types = (float, int, np.float32, np.int32)
integer_types = (int, np.int32)
numeric_types = integer_types + (float, np.float32)
# this function is needed for python3
# to convert ctypes.char_p .value back to python str
py_str = lambda x: x.decode('utf-8')
else:
string_types = (basestring,)
numeric_types = (float, int, long, np.float32, np.int32)
integer_types = (int, long, np.int32)
numeric_types = integer_types + (float, np.float32)
py_str = lambda x: x
......
# pylint: disable=wildcard-import, redefined-builtin
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from . import base
from . import ty
......@@ -19,6 +19,9 @@ from . import image
# Span
Span = base.Span
# Env
Environment = env.Environment
# Type
Type = ty.Type
TupleType = ty.TupleType
......@@ -40,3 +43,7 @@ Call = expr.Call
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
# helper functions
var = expr.var
const = expr.const
......@@ -3,6 +3,7 @@
from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
from . import _expr
NodeBase = NodeBase
......@@ -20,7 +21,19 @@ def register_relay_node(type_key=None):
return _register_tvm_node(type_key)
class RelayNode(NodeBase):
    """Base class of all Relay node objects on the Python side."""

    def astext(self):
        """Get the text format of the expression.

        Returns
        -------
        text : str
            The text format of the expression.
        """
        return _expr._text_print(self)
@register_relay_node
class Span(NodeBase):
class Span(RelayNode):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, NodeBase
from .base import register_relay_node, RelayNode
from . import _make
from . import _env
@register_relay_node
class Environment(NodeBase):
class Environment(RelayNode):
"""The global Relay environment containing functions,
options and more.
"""
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
from __future__ import absolute_import
from .base import NodeBase, register_relay_node
from . import _expr
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import ty as _ty
from .._ffi import base as _base, node as _node
from .. import nd as _nd
from .. import convert
class Expr(NodeBase):
class Expr(RelayNode):
"""The base type for all Relay expressions."""
@property
def checked_type(self):
......@@ -56,7 +60,7 @@ class Tuple(Expr):
@register_relay_node
class Var(Expr):
"""A local variable in Tvm.Relay.
"""A local variable in Relay.
Local variable can be used to declare input
arguments to a function, or intermediate variables.
......@@ -101,26 +105,26 @@ class Function(Expr):
params: List[tvm.relay.Var]
List of input parameters to the function.
ret_type: tvm.relay.Type
The return type annotation of the function.
body: tvm.relay.Expr
The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def __init__(self,
params,
ret_type,
body,
ret_type=None,
type_params=None):
if type_params is None:
type_params = convert([])
self.__init_handle_by_constructor__(
_make.Function, params, ret_type, body, type_params)
_make.Function, params, body, ret_type, type_params)
@register_relay_node
......@@ -158,7 +162,7 @@ class Let(Expr):
Parameters
----------
var: tvm.relay.Var
variable: tvm.relay.Var
The local variable to be bound.
value: tvm.relay.Expr
......@@ -167,9 +171,9 @@ class Let(Expr):
body: tvm.relay.Expr
The body of the let binding.
"""
def __init__(self, var, value, body):
def __init__(self, variable, value, body):
self.__init_handle_by_constructor__(
_make.Let, var, value, body)
_make.Let, variable, value, body)
@register_relay_node
......@@ -208,4 +212,105 @@ class TupleGetItem(Expr):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index)
debug_print = _expr._debug_print
class TupleWrapper(_node.NodeGeneric):
    """Python-side view over a Relay tuple of statically known size.

    Lets a Relay tuple expression be indexed and measured like a
    native Python tuple.

    Parameters
    ----------
    tuple_value: tvm.relay.Expr
        The input tuple

    size: int
        The size of the tuple.
    """

    def __init__(self, tuple_value, size):
        self.tuple_value = tuple_value
        self.size = size

    def asnode(self):
        """Returns the underlying Relay tuple if this wrapper is passed
        as an argument to an FFI function."""
        return self.tuple_value

    def __getitem__(self, key):
        # Index directly into the fields of the wrapped tuple expression.
        return self.tuple_value.fields[key]

    def __len__(self):
        return len(self.tuple_value.fields)
def var(name_hint,
        type_annotation=None,
        shape=None,
        dtype="float32"):
    """Create a new tvm.relay.Var.

    This is a convenience wrapper that lets the variable's type be given
    either as an explicit annotation or as a shape/dtype pair.

    Parameters
    ----------
    name_hint: str
        The name of the variable.
        This name only acts as a hint, and is not used for equality.

    type_annotation: Optional[tvm.relay.Type, str]
        The type annotation on the variable.
        When type_annotation is a str, we will create a scalar variable.

    shape: Optional[List[tvm.Expr]]
        The shape of the tensor type.

    dtype: str, optional
        The data type of the tensor.

    Examples
    --------
    .. code-block:: python

       # The following 4 lines are equivalent to each other
       x = tvm.relay.Var("x", tvm.relay.TensorType([1, 2]))
       x = tvm.relay.var("x", tvm.relay.TensorType([1, 2]))
       x = tvm.relay.var("x", shape=[1, 2])
       x = tvm.relay.var("x", shape=[1, 2], dtype="float32")

       # The following 2 lines are equivalent to each other.
       y = tvm.relay.var("x", "float32")
       y = tvm.relay.var("x", shape=(), dtype="float32")
    """
    # The two ways of specifying the type are mutually exclusive.
    if type_annotation is not None and shape is not None:
        raise ValueError("Can only specify either type_annotation or shape.")
    if shape is not None:
        # Build a tensor type from the shape/dtype pair.
        type_annotation = _ty.TensorType(shape, dtype)
    elif isinstance(type_annotation, str):
        # A bare dtype string means a scalar of that dtype.
        type_annotation = _ty.TensorType((), type_annotation)
    return Var(name_hint, type_annotation)
def const(value, dtype=None):
    """Create a constant value.

    Parameters
    ----------
    value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
        The constant value.

    dtype: str, optional
        The data type of the value.
    """
    # Python scalars (int/float/bool) and lists are first normalized
    # to a numpy array so dtype can be applied uniformly.
    if isinstance(value, _base.numeric_types) or isinstance(value, (bool, list)):
        value = _np.array(value, dtype=dtype)
    # numpy values are then lifted to a TVM NDArray.
    if isinstance(value, (_np.ndarray, _np.generic)):
        value = _nd.array(value)
    if not isinstance(value, _nd.NDArray):
        raise ValueError("value has to be scalar or NDArray")
    return Constant(value)
......@@ -11,32 +11,6 @@ from .expr import Expr, Constant, Let, Var, Function, If
from .env import Environment
class TupleWrapper(tvm._ffi.node.NodeGeneric):
"""TupleWrapper.
This class is a Python wrapper for a Relay tuple of known size.
It allows for accessing the fields of the Relay tuple as though
it were a Python tuple.
"""
def __init__(self, tuple_value, size):
self.tuple_value = tuple_value
self.size = size
def asnode(self):
"""Returns the underlying Relay tuple if this wrapper is passed
as an argument to an FFI function."""
return self.tuple_value
def __getitem__(self, key):
return self.tuple_value.fields[key]
def __len__(self):
return len(self.tuple_value.fields)
def _convert_to_value(arg, ctxt=tvm.cpu(0)):
# type: (Any, tvm.Context) -> tvm.nd.NDArray
"""Convert Python values into the appropriate types
......@@ -132,8 +106,8 @@ class PartialFunc(object):
"""Converts a PartialFunc into a :py:class:`~relay.Function`."""
return Function(
self.params,
self.ret_type,
self.body,
self.ret_type,
self.type_params)
#pylint: disable=invalid-name
......@@ -325,7 +299,7 @@ class IRBuilder(object):
def _on_exit():
bindings, _, _, ret_value = self.exit_scope()
exp = _mk_let(bindings, ret_value)
self.env.add(name, Function(params, ret_type, exp))
self.env.add(name, Function(params, exp, ret_type))
return WithScope(10, _on_exit)
......
"""Neural network operations."""
from __future__ import absolute_import as _abs
from tvm.relay.ir_builder import TupleWrapper
from ...expr import TupleWrapper
from . import _make
......@@ -145,7 +145,7 @@ def conv2d_transpose(data,
weight_layout, output_padding, out_dtype)
def softmax(data, axis):
def softmax(data, axis=1):
r"""Computes softmax.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
......@@ -158,7 +158,7 @@ def softmax(data, axis):
data: relay.Expr
The input data to the operator.
axis: int
axis: int, optional
The axis to sum over when computing softmax
Returns
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import NodeBase, register_relay_node
from .base import RelayNode, register_relay_node
from . import _make
class Type(NodeBase):
class Type(RelayNode):
"""The base type for all Relay types."""
def __eq__(self, other):
......@@ -21,27 +21,25 @@ class Type(NodeBase):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
@register_relay_node
class TensorType(Type):
"""A concrete TensorType in Relay, see tvm/relay/type.h for more details.
"""A concrete TensorType in Relay.
This is the type assigned to tensor's with a known dype and shape. For
example a tensor of `float32` and `(5, 5)`.
"""
def __init__(self, shape, dtype):
"""Construct a tensor type.
Parameters
----------
shape: list of tvm.Expr
dtype: str
Parameters
----------
shape: List[tvm.Expr]
The shape of the Tensor
Returns
-------
tensor_type: The TensorType
"""
self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)
dtype: str, optional
The content data type.
"""
def __init__(self, shape, dtype="float32"):
self.__init_handle_by_constructor__(
_make.TensorType, shape, dtype)
class Kind(IntEnum):
......
......@@ -17,11 +17,15 @@ namespace tvm {
template <typename FType>
class AttrFunctor;
#define ATTR_FUNCTOR_DEFAULT \
{ return VisitAttrDefault_(op, std::forward<Args>(args)...); }
#define ATTR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->Visit_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
return self->VisitAttr_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
}); \
// A functor for common attribute information.
......@@ -40,21 +44,21 @@ class AttrFunctor<R(const NodeRef& n, Args...)> {
* \param args Additional arguments.
* \return The result of the call
*/
virtual R Visit(const NodeRef& n, Args... args) {
virtual R VisitAttr(const NodeRef& n, Args... args) {
static FType vtable = InitVTable();
if (vtable.can_dispatch(n)) {
return vtable(n, this, std::forward<Args>(args)...);
} else {
return VisitDefault_(n, std::forward<Args>(args)...);
return VisitAttrDefault_(n.get(), std::forward<Args>(args)...);
}
}
virtual R Visit_(const ArrayNode* op, Args... args) = 0;
virtual R Visit_(const StrMapNode* op, Args... args) = 0;
virtual R Visit_(const ir::IntImm* op, Args... args) = 0;
virtual R Visit_(const ir::UIntImm* op, Args... args) = 0;
virtual R Visit_(const ir::FloatImm* op, Args... args) = 0;
virtual R Visit_(const ir::StringImm* op, Args... args) = 0;
virtual R VisitDefault_(const NodeRef& n, Args... args) = 0;
virtual R VisitAttr_(const ArrayNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const StrMapNode* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::IntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::UIntImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::FloatImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttr_(const ir::StringImm* op, Args... args) ATTR_FUNCTOR_DEFAULT;
virtual R VisitAttrDefault_(const Node* node, Args... args) = 0;
private:
// initialize the vtable.
......
......@@ -11,6 +11,10 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
}
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
v->Visit("__dict__", &dict);
}
void DictAttrsNode::InitByPackedArgs(
const runtime::TVMArgs& args, bool allow_unknown) {
for (int i = 0; i < args.size(); i += 2) {
......@@ -55,48 +59,48 @@ class AttrsEqualChecker :
if (!equal_) return false;
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (!this->Visit(lhs, rhs)) {
if (!this->VisitAttr(lhs, rhs)) {
equal_ = false;
}
return equal_;
}
bool VisitDefault_(const NodeRef& lhs, const NodeRef& other) final {
bool VisitAttrDefault_(const Node* lhs, const NodeRef& other) final {
if (lhs->derived_from<BaseAttrsNode>()) {
return static_cast<const BaseAttrsNode*>(lhs.get())->ContentEqual(other.get());
return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(other.get());
}
return lhs.same_as(other);
return lhs == other.get();
}
bool Visit_(const IntImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const IntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<IntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const UIntImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const UIntImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<UIntImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const FloatImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const FloatImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<FloatImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const StringImm* lhs, const NodeRef& other) final {
bool VisitAttr_(const StringImm* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StringImm>()) {
return lhs->value == rhs->value;
}
return false;
}
bool Visit_(const ArrayNode* lhs, const NodeRef& other) final {
bool VisitAttr_(const ArrayNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<ArrayNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (size_t i = 0; i < lhs->data.size(); ++i) {
......@@ -106,7 +110,7 @@ class AttrsEqualChecker :
return true;
}
bool Visit_(const StrMapNode* lhs, const NodeRef& other) final {
bool VisitAttr_(const StrMapNode* lhs, const NodeRef& other) final {
if (const auto* rhs = other.as<StrMapNode>()) {
if (rhs->data.size() != lhs->data.size()) return false;
for (const auto& kv : lhs->data) {
......@@ -127,38 +131,38 @@ class AttrContentHasher :
public:
size_t result_{0};
void VisitDefault_(const NodeRef& value) final {
void VisitAttrDefault_(const Node* value) final {
if (value->derived_from<BaseAttrsNode>()) {
Update(static_cast<const BaseAttrsNode*>(value.get())->ContentHash());
Update(static_cast<const BaseAttrsNode*>(value)->ContentHash());
} else {
Update(NodeHash()(value));
Update(NodeHash()(GetRef<NodeRef>(value)));
}
}
void Visit_(const IntImm* op) final {
void VisitAttr_(const IntImm* op) final {
Update(std::hash<int64_t>()(op->value));
}
void Visit_(const UIntImm* op) final {
void VisitAttr_(const UIntImm* op) final {
Update(std::hash<uint64_t>()(op->value));
}
void Visit_(const FloatImm* op) final {
void VisitAttr_(const FloatImm* op) final {
Update(std::hash<double>()(op->value));
}
void Visit_(const StringImm* op) final {
void VisitAttr_(const StringImm* op) final {
Update(std::hash<std::string>()(op->value));
}
void Visit_(const ArrayNode* op) final {
void VisitAttr_(const ArrayNode* op) final {
Update(op->data.size());
for (size_t i = 0; i < op->data.size(); ++i) {
this->Visit(NodeRef(op->data[i]));
this->VisitAttr(NodeRef(op->data[i]));
}
}
void Visit_(const StrMapNode* lhs) final {
void VisitAttr_(const StrMapNode* lhs) final {
using Entry = std::pair<std::string, NodePtr<Node> >;
std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
......@@ -166,7 +170,7 @@ class AttrContentHasher :
});
for (const Entry& kv : data) {
Update(std::hash<std::string>()(kv.first));
this->Visit(NodeRef(kv.second));
this->VisitAttr(NodeRef(kv.second));
}
}
......@@ -184,7 +188,7 @@ bool AttrsEqual::Equal(const NodeRef& lhs, const NodeRef& rhs) {
size_t AttrsHash::Hash(const NodeRef& node) {
if (!node.defined()) return 0;
AttrContentHasher hasher;
hasher.Visit(node);
hasher.VisitAttr(node);
return hasher.result_;
}
......
......@@ -208,6 +208,8 @@ class JSONAttrGetter : public AttrVisitor {
node_->type_key = node->type_key();
// specially handle global objects
auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key);
CHECK(f != nullptr)
<< "Node type \'" << node_->type_key << "\' is not registered in TVM";
if (f->fglobal_key != nullptr) {
node_->global_key = f->fglobal_key(node);
return;
......
......@@ -51,6 +51,8 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
return Span(n);
}
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SpanNode::make(args[0], args[1], args[2]);
......
/*!
* Copyright (c) 2018 by Contributors
* \file debug_printer.cc
* \brief A pretty printer for the Relay IR.
* As we have not determined a formal syntax yet, this is currently only for debugging purposes.
*/
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/environment.h>
#include <tvm/relay/error.h>
#include <iostream>
#include <sstream>
#include <vector>
#include <unordered_map>
#include <string>
#include <vector>
#include <iostream>
#include "../pass/type_functor.h"
#include "doc.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
/*!
 * \brief Render a type-parameter kind as a Doc token.
 * \param k The kind to render.
 * \return The textual Doc for the kind.
 */
Doc KindDocify(TypeParamNode::Kind k) {
  switch (k) {
    case TypeParamNode::kShapeVar: return DocOfStr("ShapeVar");
    case TypeParamNode::kShape:    return DocOfStr("Shape");
    case TypeParamNode::kBaseType: return DocOfStr("BaseType");
    case TypeParamNode::kType:     return DocOfStr("Type");
    default:
      LOG(FATAL) << "unreachable code: case not handle in kind";
      // LOG(FATAL) never returns, but the compiler cannot prove it.
      throw;
  }
}
/*!
 * \brief Apply f to every element of arr and collect the resulting Docs.
 * \param arr The array to convert.
 * \param f The per-element conversion function.
 * \return Docs for all elements, in input order.
 */
template<typename T>
std::vector<Doc> MapDocify(const tvm::Array<T>& arr, const std::function<Doc(const T&)>& f) {
  std::vector<Doc> vec;
  // Reserve up front: the result size is known, avoid repeated reallocation.
  vec.reserve(arr.size());
  for (size_t i = 0; i < arr.size(); ++i) {
    vec.push_back(f(arr[i]));
  }
  return vec;
}
/*!
 * \brief Occurrence counter used for name mangling.
 *
 * operator() returns how many times the key has been seen before this
 * call: 0 on the first call for a key, 1 on the second, and so on.
 */
template<typename T, typename Hash = std::hash<T>, typename Eq = std::equal_to<T>>
class Counter {
  std::unordered_map<T, size_t, Hash, Eq> cnt_;

 public:
  Counter() = default;
  Counter(const Counter&) = delete;
  /*!
   * \brief Record one more occurrence of t.
   * \param t The key being counted.
   * \return The number of prior occurrences of t (0 for the first call).
   */
  size_t operator()(const T& t) {
    // Single hash lookup: operator[] value-initializes to 0 on first
    // insertion, and the post-increment yields the pre-call count.
    // (The original did two lookups: count() followed by at().)
    return cnt_[t]++;
  }
};
/*!
 * \brief Append a numeric suffix to a name so generated names cannot collide.
 * \param str The base name.
 * \param s The occurrence index of the name.
 * \return The mangled name, e.g. Mangle("x", 1) == "x_1".
 *
 * The suffix is appended unconditionally. Suffixing only when s > 0 would
 * look prettier but is unsafe: for inputs x, x, x_0 it would emit
 * x, x_0, x_0 — a clash. Unconditional suffixing yields x_0, x_1, x_0_1,
 * which never clashes, because stripping the trailing _([0-9]*) exactly
 * inverts the mangling under all circumstances.
 * Separately, Var/TypeParam/GlobalVar names must also be kept from
 * clashing with each other.
 */
std::string Mangle(const std::string& str, size_t s) {
  return str + "_" + std::to_string(s);
}
constexpr size_t indent = 2;
/*!
 * \brief Sentinel key type for naming type parameters.
 * All instances compare equal, so a Counter keyed on TypeParamName
 * behaves as one shared counter for every type parameter.
 */
struct TypeParamName {
  bool operator==(const TypeParamName&) const { return true; }
};
/*!
 * \brief Hash functor for TypeParamName.
 * Constant hash is correct because every TypeParamName compares equal.
 */
struct mhash {
  size_t operator()(const ::tvm::relay::TypeParamName&) const noexcept {
    return 0;
  }
};
// Converts Relay types into Doc fragments for pretty printing.
// Inherits privately from TypeFunctor so that only Docify() is exposed.
class TypeDocifier : private TypeFunctor<Doc(const Type& n)> {
Environment env;
// Shared counter used to mangle fresh type-parameter names (tp_0, tp_1, ...).
Counter<TypeParamName, mhash> cnt;
// Memoizes the Doc chosen for each TypeParam so repeated uses of the same
// parameter print the same mangled name.
std::unordered_map<TypeParam, Doc, NodeHash, NodeEqual> map;
// Docify every type in the array, preserving order.
std::vector<Doc> DocifyTypeArray(const tvm::Array<Type>& arr) {
return MapDocify<Type>(arr, [=](const Type& t) { return Docify(t); });
}
// Docify every type parameter in the array.
std::vector<Doc> DocifyTypeParam(const tvm::Array<TypeParam>& arr) {
return MapDocify<TypeParam>(arr, [=](const TypeParam& tp) {
return Docify(tp);
});
}
// Docify every type constraint in the array.
std::vector<Doc> DocifyTypeConstraint(const tvm::Array<TypeConstraint>& arr) {
return MapDocify<TypeConstraint>(arr, [=](const TypeConstraint& tc) { return Docify(tc); });
}
// Tensor types are printed with a fixed placeholder token.
Doc VisitType_(const TensorTypeNode* t) final {
return DocOfStr("tensor");
}
// Type parameters: assign a fresh mangled name (with its kind) on first
// sight, then reuse the memoized Doc on every later occurrence.
Doc VisitType_(const TypeParamNode* p) final {
auto tp = GetRef<TypeParam>(p);
if (map.count(tp) == 0) {
auto name =
DocOfStr(Mangle("tp", cnt(TypeParamName())) +
std::string(":")) +
KindDocify(p->kind);
map.insert(std::pair<TypeParam, Doc>(tp, name));
}
return map.at(tp);
}
// Prefix d with "forall <params>," when there are type parameters.
Doc Quantify(const tvm::Array<TypeParam>& tp, const Doc& d) {
if (tp.size() == 0) {
return d;
}
return Seq("forall", DocifyTypeParam(tp), ",") + Sep() + d;
}
// Prefix d with "(<constraints>) =>" when there are type constraints.
Doc Constraint(const tvm::Array<TypeConstraint>& tc, const Doc& d) {
if (tc.size() == 0) {
return d;
}
return Seq("(", DocifyTypeConstraint(tc), ") =>") + Sep() + d;
}
// Function types print as "<args> -> ret", wrapped by quantifiers and
// constraints, grouped so the layout engine may keep it on one line.
Doc VisitType_(const FuncTypeNode* f) final {
auto inner = Seq("<", DocifyTypeArray(f->arg_types), ">") + Sep() +
DocOfStr("->") + Sep() + Docify(f->ret_type);
return Group(Quantify(f->type_params,
Constraint(f->type_constraints, inner)));
}
// Type relations print as "Relation(<args>)".
Doc VisitType_(const TypeRelationNode* r) final {
return DocOfStr("Relation") + Seq("(", DocifyTypeArray(r->args), ")");
}
// Tuple types print as "<field, field, ...>".
Doc VisitType_(const TupleTypeNode* t) final {
return Seq("<", DocifyTypeArray(t->fields), ">");
}
// Unresolved (incomplete) types print as "_".
Doc VisitType_(const IncompleteTypeNode* i) final {
return DocOfStr("_");
}
public:
TypeDocifier(const Environment& env) : env(env) { }
// Entry point: a null type handle also prints as "_".
Doc Docify(const Type& t) { return t.get() ? (*this)(t) : DocOfStr("_"); }
};
// Converts Relay expressions into Doc fragments for pretty printing.
// Inherits privately from ExprFunctor so that only Docify() is exposed.
class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
Environment env;
// Per-name-hint counter used to mangle variable names uniquely.
Counter<std::string> cnt;
// Memoizes the mangled name chosen for each Var so every occurrence
// of the same variable prints identically.
std::unordered_map<Var, std::string, NodeHash, NodeEqual> map;
// Delegate for printing any type annotations encountered.
TypeDocifier td;
// Return the stable mangled name for v, creating it on first sight.
std::string VarName(const Var& v) {
if (map.count(v) == 0) {
map.insert(std::pair<Var, std::string>(v, Mangle(v->name_hint, cnt(v->name_hint))));
}
return map.at(v);
}
// Append ":type" to d, unless the type is absent or still unresolved.
// NOTE(review): a null type handle is tolerated here — confirm whether
// callers can actually produce one, or whether this guard is dead.
Doc TypeAnnotation(const Doc& d, const Type& t) {
if (!t.get() || t.as<IncompleteTypeNode>()) {
return d;
} else {
return d + DocOfStr(":") + td.Docify(t);
}
}
// Docify every expression in the array, preserving order.
std::vector<Doc> DocifyExprArray(const tvm::Array<Expr>& arr) {
std::vector<Doc> vec;
for (size_t i = 0; i < arr.size(); ++i) {
vec.push_back(Docify(arr[i]));
}
return vec;
}
// Docify function parameters as "name:type" entries.
std::vector<Doc> DocifyParamArray(const tvm::Array<Var>& arr) {
std::vector<Doc> vec;
for (Var param : arr) {
vec.emplace_back(TypeAnnotation(DocOfStr(VarName(param)),
param->type_annotation));
}
return vec;
}
// Constants are not rendered in detail; a placeholder token is emitted.
Doc VisitExpr_(const ConstantNode* c) final {
return DocOfStr("some_constant");
}
// Tuples print as "<field, field, ...>".
Doc VisitExpr_(const TupleNode* t) final {
return Seq("<", DocifyExprArray(t->fields), ">");
}
// Variables print via their memoized mangled name.
Doc VisitExpr_(const VarNode* v) final {
return DocOfStr(VarName(GetRef<Var>(v)));
}
// Global variables print with their raw name hint (no mangling).
Doc VisitExpr_(const GlobalVarNode* g) final {
return DocOfStr(g->name_hint);
}
// Functions print as "(params):ret_type => { body }".
Doc VisitExpr_(const FunctionNode* f) final {
return Group(TypeAnnotation(Seq("(", DocifyParamArray(f->params), ")"), f->ret_type) + Sep() +
DocOfStr("=>") + Sep() +
Block(indent, "{", Docify(f->body), "}"));
}
// Calls print as "op<arg, arg, ...>".
Doc VisitExpr_(const CallNode* c) final {
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}
// Let bindings print as "let var:type = value; body".
Doc VisitExpr_(const LetNode* l) final {
return Group(DocOfStr("let") + Sep() +
TypeAnnotation(Docify(l->var), l->var->type_annotation) + Sep() +
DocOfStr("=") + Sep() + Docify(l->value) + DocOfStr(";") + Endl() +
Docify(l->body));
}
// Conditionals print as "if cond { true } else { false }".
Doc VisitExpr_(const IfNode* i) final {
return Group(DocOfStr("if") + Sep() + Docify(i->cond) + Sep() +
Block(indent, "{", Docify(i->true_branch), "}") + Sep() +
DocOfStr("else") + Sep() +
Block(indent, "{", Docify(i->false_branch), "}"));
}
// Operators print by their registered name.
Doc VisitExpr_(const OpNode* o) final {
return DocOfStr(o->name);
}
// Tuple projection prints as "tuple.index".
Doc VisitExpr_(const TupleGetItemNode* g) final {
return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
}
public:
ExprDocifier(const Environment& env) : env(env), td(env) { }
// Entry point for converting an expression into a Doc.
Doc Docify(const Expr& e) { return (*this)(e); }
};
/*! \brief Render an expression into an (un-layouted) Doc. */
Doc DocOfExpr(const Environment& env, const Expr& expr) {
  return ExprDocifier(env).Docify(expr);
}

/*! \brief Render a type into an (un-layouted) Doc. */
Doc DocOfType(const Environment& env, const Type& expr) {
  return TypeDocifier(env).Docify(expr);
}
// Render an expression all the way down to a layouted RDoc.
RDoc ExprRDoc(const Environment& env, const Expr& expr) {
return Layout(DocOfExpr(env, expr));
}
// Render a type all the way down to a layouted RDoc.
RDoc TypeRDoc(const Environment& env, const Type& expr) {
return Layout(DocOfType(env, expr));
}
// Write the pretty-printed expression to os; returns os for chaining.
std::ostream & DebugPrint(const Environment& env, const Expr& e, std::ostream& os) {
return os << ExprRDoc(env, e);
}
// Write the pretty-printed type to os; returns os for chaining.
std::ostream & DebugPrint(const Environment& env, const Type& t, std::ostream& os) {
return os << TypeRDoc(env, t);
}
/*! \brief Pretty-print an expression into a string. */
std::string PrintExpr(const Environment& env, const Expr& e) {
  std::ostringstream os;
  os << ExprRDoc(env, e);
  return os.str();
}

/*! \brief Pretty-print a type into a string. */
std::string PrintType(const Environment& env, const Type& t) {
  std::ostringstream os;
  os << TypeRDoc(env, t);
  return os.str();
}
// FFI entry point backing relay._expr._text_print / _debug_print:
// args[0] is the Environment, args[1] the node to print.
// Dispatches on whether the node is a Type or an Expr.
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
*ret = PrintExpr(args[0], Downcast<Expr>(x));
}
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file doc.h
* \brief A pretty printer DSL for constructing (Doc) and formatting (RDoc) documents.
* It is based heavily on Philip Wadler's "A prettier printer."
* See https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf
* for more details.
*
* Since the original paper relies on lazy evaluation while this implementation
* uses call-by-value, every doc function here is made maximally lazy for efficiency.
* You could probably gain speed via strictness analysis, removing some of the
* laziness where it is not needed (if this ever becomes a bottleneck).
*/
#ifndef TVM_RELAY_IR_DOC_H_
#define TVM_RELAY_IR_DOC_H_
#include <tvm/relay/error.h>
#include <unordered_map>
#include <utility>
#include <string>
#include <functional>
#include <vector>
#include <memory>
#include <ostream>
#include <map>
namespace tvm {
namespace relay {
/*! \brief A Document represents structured text.
 * Besides holding unstructured strings, it captures the different ways to compose them -
 * line break, space, indentation, representation choice.
 */
struct Doc;
/*! \brief RDoc represents a rendered document.
 * All the high-level detail of the document, such as indentation and choice, has been removed;
 * there is only one single, straightforward way to print it.
 */
struct RDoc;
//! \brief Empty document
inline Doc Nil();
//! \brief Concatenate two documents
inline Doc App(const Doc& l, const Doc& r);
//! \brief Indent a document
inline Doc Nest(size_t width, const Doc& doc);
//! \brief Lift string to a document
inline Doc DocOfStr(const std::string& text);
//! \brief New line
inline Doc Endl();
//! \brief Remove all line break from the Document.
inline Doc Flatten(const Doc& d);
/*! \brief Choose between two possible layouts.
* assume Flatten(l) == Flatten(r), and l need to be more compact.
*/
inline Doc Choose(const Doc& l, const Doc& r);
//! \brief Use a single line if possible
inline Doc Group(const Doc& d);
//! \brief print an RDoc
inline std::ostream& operator<<(std::ostream& os, const RDoc& rdoc);
/*! \brief Joins a vector of documents with a given separator document
 * \example Join(["a", "b", "c"], ", ") => "a, b, c"
* \param vec the vector of documents
* \param sep the separator between documents
*/
inline Doc Join(const std::vector<Doc>& vec, const Doc& sep);
/*! \brief Creates an indented block.
* \param indent the indentation size
* \param open the opening string
* \param body the body of the block
* \param close the closing string
*/
inline Doc Block(size_t indent, const std::string& open,
const Doc& body, const std::string& close);
/*! \brief Creates a comma-separated sequence with opening and closing strings.
* \param open the opening string
* \param body the body of the Block
* \param close the closing string
*/
inline Doc Seq(const std::string& open,
const std::vector<Doc>& body, const std::string& close);
//! \brief Either a space or a new line
inline Doc Sep();
/*! \brief Layout a document to a given width
* \param d the document to render
* \param width the line width
*/
inline RDoc Layout(const Doc& d, size_t width = 80);
// end of API, start of implementation
// Memo cell backing Lazy: holds either a pending computation, or -- after
// the first force -- a constant function returning the cached result.
template<typename T>
struct LazyNode {
  mutable std::function<T()> thunk;  // replaced by a constant fn once forced
  explicit LazyNode(const std::function<T()>& thunk) : thunk(thunk) { }
};
//! \brief A value that is computed (at most once) on demand.
template<typename T>
struct Lazy {
  std::shared_ptr<LazyNode<T> > lazy_node;
  //! Wrap a computation; it does not run until get() is called.
  explicit Lazy(const std::function<T()>& thunk)
      : lazy_node(std::make_shared<LazyNode<T>>(thunk)) { }
  //! Wrap an already-known value.
  explicit Lazy(const T& value) : Lazy([=]() { return value; }) { }
  //! Collapse a nested lazy-of-lazy into a single level.
  explicit Lazy(const Lazy<Lazy<T>>& thunk) : Lazy([=]() { return thunk.get().get(); }) { }
  //! Force the computation; the result is memoized by overwriting the
  //! thunk with a constant function that returns it immediately.
  T get() const {
    T computed = lazy_node->thunk();
    lazy_node->thunk = [=]() { return computed; };
    return computed;
  }
  //! Lazily apply func to this value, producing a new Lazy.
  template<typename R>
  Lazy<R> map(const std::function<R(const T&)>& func) const {
    Lazy<T> copy_of_this(*this);
    return Lazy<R>([=]() -> R { return func(copy_of_this.get()); });
  }
};
struct NilNode;
struct AppNode;
struct NestNode;
struct TextNode;
struct LineNode;
struct ChoiceNode;
/*! \brief The inner representation of Doc.
* a doc represent structured text,
* and can be rendered onto screen while keeping the structure.
*/
/*! \brief The inner representation of Doc.
 * A doc represents structured text and can be rendered onto the screen
 * while keeping the structure.
 */
struct DocNode {
  /* A DocNode is a tagged union of the nodes below:
   * exactly one of the pointers is non-null.
   * Their meaning is given by the construction function of the same name,
   * e.g. an AppNode is exactly a node constructed by App.
   */
  std::shared_ptr<NilNode> nil;
  std::shared_ptr<AppNode> app;
  std::shared_ptr<NestNode> nest;
  std::shared_ptr<TextNode> text;  // constructed by DocOfStr
  std::shared_ptr<LineNode> line;
  std::shared_ptr<ChoiceNode> choice;
  // Take the shared_ptrs by value and move them into the members;
  // the previous copy-initialization performed an extra atomic
  // reference-count increment/decrement per field.
  DocNode(std::shared_ptr<NilNode> nil,
          std::shared_ptr<AppNode> app,
          std::shared_ptr<NestNode> nest,
          std::shared_ptr<TextNode> text,
          std::shared_ptr<LineNode> line,
          std::shared_ptr<ChoiceNode> choice) :
    nil(std::move(nil)),
    app(std::move(app)),
    nest(std::move(nest)),
    text(std::move(text)),
    line(std::move(line)),
    choice(std::move(choice)) { }
};
/*! \brief A structured text document, wrapping a lazily computed DocNode. */
struct Doc {
  // the underlying node, computed on demand and memoized
  Lazy<DocNode> doc;
  //! Wrap an already-evaluated node.
  explicit Doc(const DocNode& ed) : doc(ed) { }
  //! Collapse a lazy Doc into a Doc (defers to the inner laziness).
  explicit Doc(const Lazy<Doc>& ldoc) :
    doc(ldoc.map<Lazy<DocNode> >([](const Doc& d){ return d.doc; })) { }
  //! Concatenation; delegates to App.
  Doc operator+(const Doc& r) const {
    return App(*this, r);
  }
  /*! \brief Case analysis on the underlying node.
   * Exactly one callback runs, matching the node's active variant;
   * the result is lazy and computed only when forced.
   */
  template<typename T>
  Lazy<T> Match(
      const std::function<T()>& nilf,
      const std::function<T(const Doc&, const Doc&)>& appf,
      const std::function<T(size_t, const Doc&)>& nestf,
      const std::function<T(const std::string&)>& textf,
      const std::function<T()>& linef,
      const std::function<T(const Doc&, const Doc&)>& choicef) const;
};
//! Constructed by Nil(): the empty document.
struct NilNode { };
//! Constructed by App(): the concatenation of two documents.
struct AppNode {
  Doc left, right;
  AppNode(const Doc& left, const Doc& right) : left(left), right(right) { }
};
//! Constructed by Nest(): doc indented by space columns.
struct NestNode {
  size_t space;
  Doc doc;
  NestNode(size_t space, const Doc& doc) : space(space), doc(doc) { }
};
//! Constructed by DocOfStr(): a literal piece of text.
struct TextNode {
  std::string text;
  explicit TextNode(const std::string& text) : text(text) { }
};
//! Constructed by Endl(): a line break.
struct LineNode { };
//! Constructed by Choose(): two alternative layouts of the same content.
struct ChoiceNode {
  Doc left, right;
  ChoiceNode(const Doc& left, const Doc& right) : left(left), right(right) { }
};
// Dispatch on the active variant of the underlying DocNode.
// Callbacks are captured by value so the returned Lazy may outlive
// the caller's scope.
template<typename T>
Lazy<T> Doc::Match(
    const std::function<T()>& nilf,
    const std::function<T(const Doc&, const Doc&)>& appf,
    const std::function<T(size_t, const Doc&)>& nestf,
    const std::function<T(const std::string&)>& textf,
    const std::function<T()>& linef,
    const std::function<T(const Doc&, const Doc&)>& choicef) const {
  return doc.map<T>([=](const DocNode& d) {
    if (d.nil) {
      return nilf();
    } else if (d.app) {
      return appf(d.app->left, d.app->right);
    } else if (d.nest) {
      return nestf(d.nest->space, d.nest->doc);
    } else if (d.text) {
      return textf(d.text->text);
    } else if (d.line) {
      return linef();
    } else {
      // exactly one pointer is set, so this must be the choice variant
      return choicef(d.choice->left, d.choice->right);
    }
  });
}
//! \brief Empty document
inline Doc Nil() {
  auto nil_node = std::make_shared<NilNode>();
  return Doc(DocNode(nil_node, nullptr, nullptr, nullptr, nullptr, nullptr));
}
//! \brief Concatenate two documents
inline Doc App(const Doc& l, const Doc& r) {
  auto app_node = std::make_shared<AppNode>(l, r);
  return Doc(DocNode(nullptr, app_node, nullptr, nullptr, nullptr, nullptr));
}
//! \brief Indent a document by width spaces.
inline Doc Nest(size_t width, const Doc& doc) {
  // Only one NestNode is needed; the previous code also created a
  // second, unused copy in a dead local variable.
  return Doc(DocNode(
      nullptr,
      nullptr,
      std::make_shared<NestNode>(width, doc),
      nullptr,
      nullptr,
      nullptr));
}
//! \brief Lift string to a document
inline Doc DocOfStr(const std::string& text) {
  auto text_node = std::make_shared<TextNode>(text);
  return Doc(DocNode(nullptr, nullptr, nullptr, text_node, nullptr, nullptr));
}
//! \brief New line
inline Doc Endl() {
  auto line_node = std::make_shared<LineNode>();
  return Doc(DocNode(nullptr, nullptr, nullptr, nullptr, line_node, nullptr));
}
/*! \brief Choose between two possible layouts.
 * Assumes Flatten(l) == Flatten(r), and that l is the more compact form.
 */
inline Doc Choose(const Doc& l, const Doc& r) {
  auto choice_node = std::make_shared<ChoiceNode>(l, r);
  return Doc(DocNode(nullptr, nullptr, nullptr, nullptr, nullptr, choice_node));
}
//! \brief Remove all line breaks from the document.
inline Doc Flatten(const Doc& d) {
  return Doc(d.Match<Doc>(
      []() { return Nil(); },
      // flatten both sides of a concatenation
      [](const Doc& l, const Doc& r) { return Flatten(l) + Flatten(r); },
      // indentation becomes irrelevant once line breaks are gone
      [](size_t space, const Doc& doc) { return Flatten(doc); },
      [](const std::string& str) { return DocOfStr(str); },
      // a line break collapses into a single space
      []() { return DocOfStr(" "); },
      // both branches of a choice flatten identically (Choose's
      // precondition), so keep the left one
      [](const Doc& l, const Doc& r) { return Flatten(l); }));
}
//! \brief Use a single line if possible
inline Doc Group(const Doc& d) {
  Doc one_line = Flatten(d);
  return Choose(one_line, d);
}
struct RNilNode;
struct RTextNode;
struct RLineNode;
/*! \brief Tagged union backing RDoc; exactly one pointer is non-null. */
struct RDocNode {
  std::shared_ptr<RNilNode> rnil;    // end of document
  std::shared_ptr<RTextNode> rtext;  // text followed by the rest
  std::shared_ptr<RLineNode> rline;  // newline + indent, then the rest
  RDocNode(const std::shared_ptr<RNilNode>& rnil,
           const std::shared_ptr<RTextNode>& rtext,
           const std::shared_ptr<RLineNode>& rline) :
    rnil(rnil), rtext(rtext), rline(rline) { }
};
/*! \brief RDoc represents a rendered document.
 * All high-level detail of the document, such as indentation and
 * alternatives, has been resolved; there is only one single,
 * straightforward way to print it.
 */
struct RDoc {
  // lazily computed spine of the rendered document
  Lazy<RDocNode> doc;
  //! Wrap an already-evaluated node.
  explicit RDoc(const RDocNode& d) : doc(d) { }
  //! Collapse a lazy RDoc into an RDoc.
  explicit RDoc(const Lazy<RDoc>& ldoc) :
    doc(ldoc.map<Lazy<RDocNode>>([](const RDoc& d){ return d.doc; })) { }
  /*! \brief Case analysis on the underlying node; the result is lazy. */
  template<typename T>
  Lazy<T> Match(
      const std::function<T()> &rnilf,
      const std::function<T(const std::string&, const RDoc&)>& rtextf,
      const std::function<T(size_t, const RDoc&)>& rlinef) const;
};
// Print an RDoc by walking its lazy spine.
// The lambdas return &os because Match needs a returnable value type;
// the trailing .get() forces the traversal and the result is
// dereferenced back into a reference.
inline std::ostream& operator<<(std::ostream& os, const RDoc& rdoc) {
  return *rdoc.Match<std::ostream*>(
      [&]() { return & os; },
      [&](const std::string& text, const RDoc& r) {
        // emit the text, then recursively print the rest
        return & (os << text << r);
      },
      [&](size_t space, const RDoc& r) {
        // a line renders as newline + `space` columns of indentation
        return & (os << std::endl << std::string(space, ' ') << r);
      }).get();
}
//! End of the rendered document.
struct RNilNode { };
//! A run of text followed by the rest of the document.
struct RTextNode {
  std::string text;
  RDoc rest;
  RTextNode(const std::string& text, const RDoc& rest) : text(text), rest(rest) { }
};
//! A line break indented by space columns, followed by the rest.
struct RLineNode {
  size_t space;
  RDoc rest;
  RLineNode(size_t space, const RDoc& rest) : space(space), rest(rest) { }
};
//! \brief Empty RDoc
inline RDoc RNil() {
  return RDoc(RDocNode(std::make_shared<RNilNode>(), nullptr, nullptr));
}
//! \brief RDoc that begins with the given string
inline RDoc RText(const std::string& text, const RDoc& rest) {
  auto node = std::make_shared<RTextNode>(text, rest);
  return RDoc(RDocNode(nullptr, node, nullptr));
}
//! \brief RDoc that begins with a new line, followed by space spaces
inline RDoc RLine(size_t space, const RDoc& rest) {
  auto node = std::make_shared<RLineNode>(space, rest);
  return RDoc(RDocNode(nullptr, nullptr, node));
}
// Dispatch on the active variant of the underlying RDocNode.
template<typename T>
Lazy<T> RDoc::Match(
    const std::function<T()>& rnilf,
    const std::function<T(const std::string&, const RDoc&)>& rtextf,
    const std::function<T(size_t, const RDoc&)>& rlinef) const {
  return doc.map<T>([=](const RDocNode& rdoc) {
    if (rdoc.rnil) {
      return rnilf();
    } else if (rdoc.rtext) {
      return rtextf(rdoc.rtext->text, rdoc.rtext->rest);
    } else {
      // exactly one pointer is set, so this must be a line node
      return rlinef(rdoc.rline->space, rdoc.rline->rest);
    }
  });
}
template<typename T>
struct List;
/*! \brief Strict cell of a lazy list: a null cons means the empty list. */
template<typename T>
struct EagerList {
  const std::shared_ptr<std::pair<T, List<T>>> cons;  // head + tail, or null
};
//! \brief lazy singly-linked list
template<typename T>
struct List {
  Lazy<EagerList<T> > l;
  //! The empty list.
  List() : l([]() { return EagerList<T>({nullptr}); }) { }
  //! Cons: lazily prepend t onto l.
  List(const T& t, const List<T>& l) :
    l([=]() { return EagerList<T>({std::make_shared<std::pair<T, List<T>>>(t, l)}); }) { }
  //! Case analysis: nullf for empty, consf(head, tail) otherwise; lazy result.
  template<typename R>
  Lazy<R> Match(const std::function<R()>& nullf,
                const std::function<R(const T&, const List<T>&)>& consf) const {
    return l.template map<R>([=](const EagerList<T>& l) {
      if (l.cons) {
        return consf(l.cons->first, l.cons->second);
      } else {
        return nullf();
      }
    });
  }
};
//! \brief Does x fit into a line with w characters remaining?
inline bool Fits(int w, const RDoc& x) {
  // NOTE: cast size() to int before subtracting. In the previous code
  // `w - s.size()` promoted w to an unsigned type, relying on
  // implementation-defined narrowing when converting back to int for
  // the recursive call with a negative budget.
  return (w >= 0) && x.Match<bool>(
      []() { return true; },
      [=](const std::string& s, const RDoc& rest) {
        return Fits(w - static_cast<int>(s.size()), rest);
      },
      // everything up to the next line break fits by definition
      [](size_t space, const RDoc& rest) { return true; }).get();
}
//! \brief Choose the layout that fits best within the remaining width.
inline RDoc Better(size_t w, size_t k, const RDoc& x, const RDoc& y) {
  if (Fits(w - k, x)) {
    return x;
  }
  return y;
}
// Worklist entry: indentation level paired with the Doc still to render.
typedef std::pair<size_t/*indent size*/, Doc> best_arg;
// Render the worklist l into an RDoc, given wrap width w and k
// characters already used on the current line.
inline RDoc Best(size_t w/*wrap width*/, size_t k/*space used*/,
                 const List<best_arg>& l/*to be rendered*/) {
  return RDoc(l.Match<RDoc>(
      // empty worklist: nothing left to render
      []() { return RNil(); },
      [=](const best_arg& p, const List<best_arg>& z) {
        return RDoc(p.second.Match<RDoc>(
            // Nil contributes nothing
            [=]() { return Best(w, k, z); },
            // App: push both halves onto the worklist at the same indent
            [=](const Doc& x, const Doc& y) {
              return Best(
                  w,
                  k,
                  List<best_arg>(best_arg(p.first, x), List<best_arg>(best_arg(p.first, y), z))); },
            // Nest: bump the indent for the nested doc
            [=](size_t j, const Doc& x) {
              return Best(w, k, List<best_arg>(best_arg(p.first + j, x), z)); },
            // Text consumes its length from the current line budget
            [=](const std::string& text) { return RText(text, Best(w, k + text.size(), z)); },
            // Line: start a fresh line at the current indent
            [=]() { return RLine(p.first, Best(w, p.first, z)); },
            // Choice: keep whichever branch fits (see Better)
            [=](const Doc& x, const Doc& y) {
              return Better(
                  w,
                  k,
                  Best(w, k, List<best_arg>(best_arg(p.first, x), z)),
                  Best(w, k, List<best_arg>(best_arg(p.first, y), z))); }));
      }));
}
/*! \brief Joins a vector of documents with a given separator document.
 * \example Join(["a", "b", "c"], ", ") => "a, b, c"
 * \param vec the vector of documents
 * \param sep the separator placed between adjacent documents
 */
inline Doc Join(const std::vector<Doc>& vec, const Doc& sep) {
  Doc result = Nil();
  for (size_t i = 0; i < vec.size(); ++i) {
    if (i != 0) {
      result = result + sep;
    }
    result = result + vec[i];
  }
  return result;
}
/*! \brief Creates an indented block.
 * \param indent the indentation size
 * \param open the opening string
 * \param body the body of the block
 * \param close the closing string
 */
inline Doc Block(size_t indent, const std::string& open,
                 const Doc& body, const std::string& close) {
  Doc indented_body = Nest(indent, Endl() + body);
  return DocOfStr(open) + indented_body + Endl() + DocOfStr(close);
}
/*! \brief Creates a comma-separated sequence with opening and closing strings.
 * \param open the opening string
 * \param body the documents to separate with commas
 * \param close the closing string
 */
inline Doc Seq(const std::string& open,
               const std::vector<Doc>& body, const std::string& close) {
  Doc elements = Join(body, DocOfStr(",") + Endl());
  return Group(DocOfStr(open) + Nest(open.size(), elements) + DocOfStr(close));
}
//! \brief Either a space or a new line
inline Doc Sep() {
  Doc space = DocOfStr(" ");
  return Choose(space, Endl());
}
/*! \brief Layout a document to a given width.
 * \param d the document to render
 * \param width the line width
 */
inline RDoc Layout(const Doc& d, size_t width) {
  List<best_arg> worklist(best_arg(0, d), List<best_arg>());
  return Best(width, 0, worklist);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_IR_DOC_H_
......@@ -100,6 +100,8 @@ void EnvironmentNode::Merge(const Environment &env) {
}
}
TVM_REGISTER_NODE_TYPE(EnvironmentNode);
TVM_REGISTER_API("relay._make.Environment")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]);
......
......@@ -17,6 +17,8 @@ Constant ConstantNode::make(runtime::NDArray data) {
return Constant(n);
}
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ConstantNode::make(args[0]);
......@@ -44,6 +46,8 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
return Tuple(n);
}
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleNode::make(args[0]);
......@@ -61,6 +65,8 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
return Var(n);
}
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = VarNode::make(args[0], args[1]);
......@@ -82,6 +88,8 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
return GlobalVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = GlobalVarNode::make(args[0]);
......@@ -94,13 +102,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
Function FunctionNode::make(tvm::Array<Var> params,
Type ret_type,
Expr body,
Type ret_type,
tvm::Array<TypeParam> type_params) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
n->params = std::move(params);
n->ret_type = std::move(ret_type);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
return Function(n);
}
......@@ -113,6 +121,8 @@ FuncType FunctionNode::func_type_annotation() const {
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
......@@ -135,6 +145,8 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
return Call(n);
}
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]);
......@@ -154,6 +166,8 @@ Let LetNode::make(Var var, Expr value, Expr body) {
return Let(n);
}
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = LetNode::make(args[0], args[1], args[2]);
......@@ -173,6 +187,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
return If(n);
}
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IfNode::make(args[0], args[1], args[2]);
});
......@@ -190,6 +206,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
return TupleGetItem(n);
}
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleGetItemNode::make(args[0], args[1]);
});
......
......@@ -92,7 +92,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return FunctionNode::make(params, ret_type, body, ty_params);
return FunctionNode::make(params, body, ret_type, ty_params);
}
}
......
/*!
* Copyright (c) 2018 by Contributors
* \file text_printer.cc
* \brief Text printer to print relay in text form.
*/
#include <tvm/relay/environment.h>
#include <tvm/relay/expr_functor.h>
#include <sstream>
#include <utility>
#include "../pass/type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
namespace relay {
/*!
* \brief the text value used in text printer.
* Defined as a struct for future compatibility reason
*/
/*!
 * \brief the text value used in text printer.
 * Defined as a struct for future compatibility reasons.
 */
struct TextValue {
  /*! \brief The str representation */
  std::string name;
  // default constructor: empty name
  TextValue() {}
  // sink parameter: move the string in instead of copying it
  explicit TextValue(std::string name) : name(std::move(name)) {}
};
// operator overloading: a TextValue prints as its bare name
inline std::ostream& operator<<(std::ostream& os, const TextValue& val) {  // NOLINT(*)
  return os << val.name;
}
/*!
* \brief Meta data context for TextPrinter.
*
* This is an important part to enable bi-directional serializability.
* We use tvm's Node system to build the current IR.
* It can be hard to design a text format for all the possible nodes
* as the set of nodes can grow when we do more extensions.
*
* Instead of trying to design readable text format for every nodes,
* we support a meta-data section in the text format.
* We allow the text format to refer to a node in the meta-data section.
*
* The meta-data section is a json serialized string of an Array<NodeRef>.
* Each element in the meta-data section can be referenced by the text format.
* Each meta data node is printed in the following format.
*
* meta.<type-key-of-node>(<index-in-meta-section>)
*
* Specifically, consider the following IR(constructed by python).
*
* \code
*
* n = tvm.var("n")
* x = tvm.relay.var("x", shape=(n, 1))
* f = tvm.relay.Function([x], x)
* print(f.astext())
*
* \endcode
*
* The corresponding text format is shown in the following code block.
*
* \code
*
* function(%x: Tensor[(meta.Variable(id=0),), float32]) {
* %x
* }
* # Meta data section is a json-serialized string
* # of the following array.
* # [tvm.var("n")]
*
* \endcode
*
* Note that we store tvm.var("n") in the meta data section.
 * Since it is stored at index 0 in the meta-data section,
* we print it as meta.Variable(0).
*
* The text parser can recover this object by loading from the corresponding
* location in the meta data section.
 *
 * This is a design trade-off.
 * It allows us to embed any meta-data in the text format,
 * while still being able to tweak the text part of the printed IR easily.
*/
/*!
 * \brief Meta data context for TextPrinter.
 * Tracks nodes that are printed as meta-data references of the form
 * meta.<type-key>(id=<index>); the table itself can then be serialized
 * to JSON so a parser can recover the original nodes.
 */
class TextMetaDataContext {
 public:
  /*!
   * \brief Get text representation of a meta node.
   * \param node The node to be converted to a meta node.
   * \return A string of the form meta.<type-key>(id=<index>).
   */
  std::string GetMetaNode(const NodeRef& node) {
    auto it = meta_index_.find(node);
    int64_t index;
    if (it == meta_index_.end()) {
      // first occurrence: append the node and remember its slot
      index = static_cast<int64_t>(meta_data_.size());
      meta_data_.push_back(node);
      meta_index_[node] = index;
    } else {
      index = it->second;
    }
    std::ostringstream os;
    os << "meta." << node->type_key() << "(id=" << index << ")";
    return os.str();
  }
  /*!
   * \brief Get the meta-data section in json format.
   * \return The meta-data string; empty when no node was recorded.
   */
  std::string GetMetaSection() const {
    if (meta_data_.size() == 0) return std::string();
    return SaveJSON(Array<NodeRef>(meta_data_));
  }

 private:
  /*! \brief additional meta-data stored in TVM json format */
  std::vector<NodeRef> meta_data_;
  /*! \brief map from meta-data node to its index */
  std::unordered_map<NodeRef, int64_t, NodeHash, NodeEqual> meta_index_;
};
class TextPrinter :
public ExprFunctor<TextValue(const Expr&)> ,
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public:
/*!
 * \brief Print a node to string.
 * \param node The node to print: Function, Environment, Type or Expr.
 * \return The string representation; a meta-data section is appended
 *         when any sub-node had to be printed as a meta reference.
 */
std::string Print(const NodeRef& node) {
  // check the concrete Function/Environment cases before the generic
  // derived Type/Expr fallbacks
  if (node.as<FunctionNode>()) {
    this->PrintFunc(Downcast<Function>(node));
  } else if (node.as<EnvironmentNode>()) {
    this->PrintEnv(Downcast<Environment>(node));
  } else if (node.as_derived<TypeNode>()) {
    this->PrintType(Downcast<Type>(node), stream_);
  } else if (node.as_derived<ExprNode>()) {
    this->PrintExpr(Downcast<Expr>(node));
  } else {
    // anything else: fall back to the node's default stream output
    stream_ << node;
  }
  std::string meta_json = meta_.GetMetaSection();
  if (meta_json.length() != 0) {
    // append meta data at the end, as a raw string block
    stream_ << "# meta data\n"
            << "r\"\"\"\n"
            << meta_json << "\n"
            << "\"\"\"";
  }
  return stream_.str();
}
void PrintFunc(const Function& func) {
this->PrintFuncInternal("function", func);
stream_ << "\n";
}
void PrintEnv(const Environment& env) {
int counter = 0;
for (const auto& kv : env->functions) {
std::ostringstream os;
if (counter++ != 0) {
stream_ << "\n";
}
os << "def @" << kv.first->name_hint;
this->PrintFuncInternal(os.str(), kv.second);
stream_ << "\n";
}
}
void PrintExpr(const Expr& expr) {
TextValue val = GetValue(expr);
stream_ << val << "\n";
}
/*!
 * \brief Get the text representation of expr.
 *
 * This function may generate additional instructions
 * in order to compute the final result id of expr.
 *
 * When recursively printing out an Expr,
 * the caller should always call GetValue on its children first,
 * then print to stream_ using the obtained values.
 *
 * This avoids subsequent GetValue calls emitting
 * additional instructions that would get mixed with the partial
 * instruction printed by the caller.
 *
 * \param expr The input expression.
 * \return The text value of Expr.
 */
TextValue GetValue(const Expr& expr) {
  // memoized: each expression is visited and printed at most once
  auto it = memo_.find(expr);
  if (it != memo_.end()) return it->second;
  TextValue val = this->VisitExpr(expr);
  memo_[expr] = val;
  return val;
}
//------------------------------------
// Overload of Expr printing functions
//------------------------------------
// Constants: print simple CPU scalars inline; everything else is
// recorded in the meta-data section.
TextValue VisitExpr_(const ConstantNode* op) final {
  // Print out simple scalars directly.
  if (op->is_scalar()) {
    DataType dtype = TVMType2Type(op->data->dtype);
    CHECK_EQ(op->data->ctx.device_type, kDLCPU);
    if (dtype == Int(32)) {
      return ConstScalar(dtype, static_cast<const int32_t*>(op->data->data));
    } else if (dtype == Int(64)) {
      return ConstScalar(dtype, static_cast<const int64_t*>(op->data->data));
    } else if (dtype == Float(32)) {
      return ConstScalar(dtype, static_cast<const float*>(op->data->data));
    } else if (dtype == Float(64)) {
      return ConstScalar(dtype, static_cast<const double*>(op->data->data));
    }
  }
  // default fall-back, record it as a meta node.
  // (a dead, unused ostringstream local was removed from this path)
  TextValue id = this->AllocTempVar();
  this->PrintIndent();
  stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op));
  this->PrintEndInst("\n");
  return id;
}
// Tuples print as `%id = (f0, f1, ...)`.
TextValue VisitExpr_(const TupleNode* op) final {
  // Recursively obtain ids for all fields before emitting the current
  // line, so nested instructions are not interleaved with ours.
  std::vector<TextValue> field_ids;
  field_ids.reserve(op->fields.size());
  for (Expr field : op->fields) {
    field_ids.push_back(GetValue(field));
  }
  TextValue id = AllocTempVar();
  PrintIndent();
  stream_ << id << " = (";
  for (size_t i = 0; i < field_ids.size(); ++i) {
    if (i != 0) {
      stream_ << ", ";
    }
    stream_ << field_ids[i];
  }
  stream_ << ')';
  PrintEndInst("\n");
  return id;
}
TextValue VisitExpr_(const VarNode* op) final {
  Var var = GetRef<Var>(op);
  // Reaching here means the var was not bound by an enclosing let or
  // function parameter, so declare it as a free variable.
  TextValue val = AllocVarName(var);
  this->PrintIndent();
  stream_ << "free_var ";
  this->PrintVarDecl(var, stream_);
  this->PrintEndInst("\n");
  return val;
}
// Global variables print as @name; no instruction is emitted.
TextValue VisitExpr_(const GlobalVarNode* op) final {
  std::string text = "@" + op->name_hint;
  return TextValue(text);
}
// Local functions bind to a fresh temporary: %id = function(...) {...}
TextValue VisitExpr_(const FunctionNode* op) final {
  TextValue id = AllocTempVar();
  std::ostringstream prefix;
  prefix << id << " = function";
  PrintFuncInternal(prefix.str(), GetRef<Function>(op));
  PrintEndInst("\n");
  return id;
}
TextValue VisitExpr_(const CallNode* op) final {
  // TODO(tqchen, M.K.): support generic call
  // possibly through meta-data
  CHECK_EQ(op->type_args.size(), 0U)
      << "generic call not yet supported";
  // evaluate callee and arguments first, so their instructions are
  // emitted before the call line itself
  TextValue call_op = GetValue(op->op);
  std::vector<TextValue> args;
  for (Expr arg : op->args) {
    args.emplace_back(GetValue(arg));
  }
  TextValue id = this->AllocTempVar();
  this->PrintIndent();
  stream_ << id << " = " << call_op << "(";
  for (size_t i = 0; i < args.size(); ++i) {
    stream_ << args[i];
    if (i + 1 != args.size()) {
      stream_ << ", ";
    }
  }
  // attributes (if any) are appended by PrintCallAttrs (defined
  // out-of-line) before the closing parenthesis
  this->PrintCallAttrs(op->op, op->attrs, stream_);
  stream_ << ")";
  this->PrintEndInst("");
  // attach the checked type as a trailing comment when available
  this->PrintOptionalInfo(GetRef<Expr>(op));
  stream_ << '\n';
  return id;
}
// A let chain opens a fresh textual scope: %id = { let ...; ... }
TextValue VisitExpr_(const LetNode* op) final {
  TextValue id = AllocTempVar();
  PrintIndent();
  stream_ << id << " = ";
  PrintScope(GetRef<Expr>(op));
  PrintEndInst("\n");
  return id;
}
// A conditional also opens a fresh scope: %id = { if (...) {...} ... }
TextValue VisitExpr_(const IfNode* op) final {
  TextValue id = AllocTempVar();
  PrintIndent();
  stream_ << id << " = ";
  PrintScope(GetRef<Expr>(op));
  PrintEndInst("\n");
  return id;
}
// Operators are referenced directly by name; no instruction is emitted.
TextValue VisitExpr_(const OpNode* op) final {
  std::string name = op->name;
  return TextValue(name);
}
// Tuple projection prints as `%id = %tuple[index]`.
TextValue VisitExpr_(const TupleGetItemNode* op) final {
  // evaluate the tuple first, so its instructions precede ours
  TextValue tuple_id = GetValue(op->tuple);
  TextValue id = AllocTempVar();
  PrintIndent();
  stream_ << id << " = " << tuple_id << "[" << op->index << "]";
  PrintEndInst("\n");
  return id;
}
/*!
* \brief Print the type to os
* \param type The type to be printed.
* \param os The output type.
*/
void PrintType(const Type& type, std::ostream& os) { // NOLINT(*)
this->VisitType(type, os);
}
//------------------------------------
// Overload of Expr printing functions
//------------------------------------
void VisitType_(const TensorTypeNode* node, std::ostream& os) final {  // NOLINT(*)
  // scalar type: print just the dtype, e.g. "float32"
  if (node->shape.size() == 0) {
    os << runtime::TVMType2String(Type2TVMType(node->dtype));
    return;
  }
  os << "Tensor[(";
  for (size_t i = 0; i < node->shape.size(); ++i) {
    // shape dims may be symbolic, so print them via the attr printer
    this->PrintAttr(node->shape[i], os);
    if (i + 1 != node->shape.size()) {
      os << ", ";
    }
  }
  // conform to python tuple format (1,)
  if (node->shape.size() == 1) {
    os << ",";
  }
  os << "), " << runtime::TVMType2String(Type2TVMType(node->dtype)) << "]";
}
void VisitTypeDefault_(const Node* node, std::ostream& os) final {  // NOLINT(*)
  // fallback: types without a dedicated printer become meta references
  std::string ref = meta_.GetMetaNode(GetRef<NodeRef>(node));
  os << ref;
}
/*!
* \brief Print an attribute value to os.
* \param value The value to be printed.
* \param os The output type.
*/
void PrintAttr(const NodeRef& value, std::ostream& os) { // NOLINT(*)
this->VisitAttr(value, os);
}
//------------------------------------
// Overload of Attr printing functions
//------------------------------------
// Array attributes print as [e0, e1, ...].
void VisitAttr_(const ArrayNode* op, std::ostream& os) final {  // NOLINT(*)
  os << "[";
  for (size_t i = 0; i < op->data.size(); ++i) {
    if (i != 0) {
      os << ", ";
    }
    this->PrintAttr(NodeRef(op->data[i]), os);
  }
  os << "]";
}
// Fallback: attributes with no dedicated printer become meta references.
void VisitAttrDefault_(const Node* op, std::ostream& os) final {  // NOLINT(*)
  os << meta_.GetMetaNode(GetRef<NodeRef>(op));
}
// Signed integer immediate, printed via the scalar helper.
void VisitAttr_(const ir::IntImm* op, std::ostream& os) final {  // NOLINT(*)
  this->PrintConstScalar(op->type, &(op->value), os);
}
// Unsigned integer immediate.
void VisitAttr_(const ir::UIntImm* op, std::ostream& os) final {  // NOLINT(*)
  this->PrintConstScalar(op->type, &(op->value), os);
}
// Floating point immediate.
void VisitAttr_(const ir::FloatImm* op, std::ostream& os) final {  // NOLINT(*)
  this->PrintConstScalar(op->type, &(op->value), os);
}
// String immediate, printed quoted.
void VisitAttr_(const ir::StringImm* op, std::ostream& os) final {  // NOLINT(*)
  this->PrintString(op->value, os);
}
protected:
/*!
* \brief Print attributes after call.
* \param op The operator to be called.
* \param attrs The attributes.
* \param os The output stream.
*/
void PrintCallAttrs(const Expr& op, const Attrs& attrs, std::ostream& os); // NOLINT(*)
/*!
 * \brief Print a new scope.
 * \param body The body.
 */
void PrintScope(Expr body) {
  // emit "{", the body at one more indent level, then a matching "}"
  stream_ << "{\n";
  int sid = this->BeginScope();
  this->PrintScopeBody(body);
  this->EndScope(sid);
  this->PrintIndent();
  stream_ << "}";
}
/*!
 * \brief Print the body of a new scope without {}.
 *
 * This function keeps printing a continuous sequence
 * of let/if scopes without introducing a new scope in the text.
 *
 * \param body The body.
 */
void PrintScopeBody(Expr body) {
  if (const LetNode* let = body.as<LetNode>()) {
    // evaluate the bound value before naming the variable
    TextValue value = GetValue(let->value);
    AllocVarName(let->var);
    // let var = value;
    this->PrintIndent();
    stream_ << "let ";
    this->PrintVarDecl(let->var, stream_);
    stream_ << " = " << value;
    this->PrintEndInst("\n");
    // continue in the same textual scope for chained lets
    this->PrintScopeBody(let->body);
  } else if (const IfNode* ifnode = body.as<IfNode>()) {
    TextValue cond = GetValue(ifnode->cond);
    this->PrintIndent();
    stream_ << "if (" << cond << ") ";
    this->PrintScope(ifnode->true_branch);
    this->PrintIndent();
    stream_ << "else ";
    this->PrintScope(ifnode->false_branch);
    this->PrintEndInst("\n");
  } else {
    // ordinary expression: its value is the result of the scope
    TextValue value = GetValue(body);
    this->PrintIndent();
    stream_ << value;
    this->PrintEndInst("\n");
  }
}
/*!
 * \brief Internal function to print a function argument list and its body.
 * \param prefix The prefix before the argument list, e.g. "def @f".
 * \param fn The function to be printed.
 */
void PrintFuncInternal(std::string prefix, const Function& fn) {
  // TODO(tqchen, M.K.) support generic function
  // Possibly through meta-data
  CHECK_EQ(fn->type_params.size(), 0U)
      << "generic fn not yet supported";
  this->PrintIndent();
  stream_ << prefix << "(";
  // subsequent parameters line up under the first one
  size_t decl_indent = prefix.length() + 1;
  for (size_t i = 0; i < fn->params.size(); ++i) {
    if (i != 0) {
      this->PrintIndent(decl_indent);
    }
    AllocVarName(fn->params[i]);
    this->PrintVarDecl(fn->params[i], stream_);
    if (i + 1 != fn->params.size()) {
      stream_ << ",\n";
    }
  }
  stream_ << ") ";
  if (fn->ret_type.defined()) {
    // NOTE(review): together with the ") " above this emits two spaces
    // before "->" -- confirm whether that is the intended format.
    stream_ << " -> ";
    this->PrintType(fn->ret_type, stream_);
  }
  this->PrintScope(fn->body);
}
/*!
* \brief Print additional info about expr in comment.
* \param expr The expression.
*/
void PrintOptionalInfo(const Expr& expr) {
// additional information in comment.
if (expr->checked_type_.defined()) {
stream_ << " # ty=";
this->PrintType(expr->checked_type(), stream_);
}
}
/*!
* \brief print var_name[:type]
* \param var The variable to be printed
* \param os The output stream
*/
void PrintVarDecl(const Var& var, std::ostream& os) { // NOLINT(*)
TextValue v = GetValue(var);
os << v;
if (var->type_annotation.defined()) {
os << ": ";
this->PrintType(var->type_annotation, os);
}
}
/*!
 * \brief Render a constant scalar into a TextValue.
 * \param dtype The data type.
 * \param data The pointer to the data.
 * \tparam T the content data type holding the data.
 */
template<typename T>
TextValue ConstScalar(DataType dtype, const T* data) {
  std::ostringstream buffer;
  PrintConstScalar(dtype, data, buffer);
  return TextValue(buffer.str());
}
/*!
 * \brief special method to print out a const scalar
 * \param dtype The data type
 * \param data The pointer to hold the data.
 * \param os The output stream.
 * \tparam T the in-memory type of the scalar.
 */
template<typename T>
void PrintConstScalar(DataType dtype, const T* data, std::ostream& os) {  // NOLINT(*)
  if (dtype == Int(32)) {
    // plain integer literal
    os << data[0];
  } else if (dtype == Float(32)) {
    // float literals carry an 'f' suffix
    os << data[0] << 'f';
  } else if (dtype == Bool()) {
    PrintBool(data[0] != 0, os);
  } else {
    // any other dtype prints as dtype(value)
    os << dtype << "(" << data[0] << ")";
  }
}
/*!
 * \brief Print a constant bool value using Python-style literals.
 * \param value The value to be printed.
 * \param os The output stream.
 */
void PrintBool(bool value, std::ostream& os) { // NOLINT(*)
  os << (value ? "True" : "False");
}
/*!
 * \brief Print a constant string wrapped in double quotes.
 * \param value The value to be printed.
 * \param os The output stream.
 */
void PrintString(const std::string& value, std::ostream& os) { // NOLINT(*)
  // TODO(M.K.): add escape.
  os << '"' << value << '"';
}
/*!
 * \brief Get a unique name with the corresponding prefix.
 * \param prefix The prefix of the name.
 * \return The unique name (the prefix itself if never seen before).
 */
std::string GetUniqueName(std::string prefix) {
  auto it = name_alloc_map_.find(prefix);
  if (it != name_alloc_map_.end()) {
    // Prefix already handed out: append an increasing counter until we
    // reach a name that has never been allocated.
    std::string candidate;
    do {
      std::ostringstream os;
      os << prefix << (++it->second);
      candidate = os.str();
    } while (name_alloc_map_.count(candidate) != 0);
    prefix = candidate;
  }
  // Register the final name so later calls cannot reuse it.
  name_alloc_map_[prefix] = 0;
  return prefix;
}
/*!
 * \brief Mark the beginning of a new scope.
 * \return The scope id.
 */
int BeginScope() {
  // The new scope's id is its index in scope_valid_.
  int sid = static_cast<int>(scope_valid_.size());
  scope_valid_.push_back(true);
  // Each nested scope indents by two spaces.
  indent_ += 2;
  return sid;
}
/*!
 * \brief Mark the end of an old scope.
 * \param scope_id The scope id to be ended, as returned by BeginScope.
 */
void EndScope(int scope_id) {
  // The slot is kept (not popped) so earlier scope ids stay stable.
  scope_valid_[scope_id] = false;
  indent_ -= 2;
}
/*!
 * \brief Print the indent to the stream.
 * \param more_indent More indentation besides the current one.
 */
void PrintIndent(int64_t more_indent = 0) {
  // Hoist the bound and use an int64_t counter: the original compared
  // an int counter against the int64_t sum, mixing signed widths.
  const int64_t total = indent_ + more_indent;
  for (int64_t i = 0; i < total; ++i) {
    stream_ << ' ';
  }
}
/*!
 * \brief Print the end of one instruction.
 * \param suffix The terminator to emit (e.g. "\n" or "").
 */
void PrintEndInst(const char* suffix) {
  stream_ << suffix;
}
/*!
 * \brief Allocate a temporary value name, e.g. %0, %1, ...
 * \return A new text value.
 */
TextValue AllocTempVar() {
  std::ostringstream os;
  // Temporaries are numbered by a monotonically increasing counter.
  os << '%' << temp_var_counter_++;
  return TextValue(os.str());
}
/*!
 * \brief Allocate a printable name to a variable.
 * \param var The input variable.
 * \return The corresponding name.
 */
TextValue AllocVarName(const Var& var) {
  // Deduplicate against previously allocated names, e.g. %x, %x1, %x2.
  std::string name = GetUniqueName('%' + var->name_hint);
  TextValue val(name);
  // Each variable may only be assigned a name once.
  CHECK(!memo_.count(var));
  memo_[var] = val;
  return val;
}
 private:
  class AttrPrinter;
  friend class AttrPrinter;
  /*! \brief meta data context, holds nodes printed as meta references */
  TextMetaDataContext meta_;
  /*! \brief Check whether a scope (by id) is still valid */
  std::vector<bool> scope_valid_;
  /*! \brief The current indentation value, in spaces */
  int indent_{0};
  /*! \brief name allocation map: name prefix -> last counter used */
  std::unordered_map<std::string, int> name_alloc_map_;
  /*! \brief Map from expression to its text value */
  std::unordered_map<Expr, TextValue, NodeHash, NodeEqual> memo_;
  /*! \brief counter of temporary variables (%0, %1, ...) */
  int64_t temp_var_counter_{0};
  /*! \brief Output stream accumulating the printed text */
  std::ostringstream stream_;
};
/*!
 * \brief Attribute printer which prints the attributes as
 *        ", key=value" pairs appended to a call's argument list.
 */
class TextPrinter::AttrPrinter: public AttrVisitor {
 public:
  AttrPrinter(std::ostream& stream, TextPrinter* parent) // NOLINT(*)
      : stream_(stream), parent_(parent) {}
  // Plain numeric attributes print via operator<<.
  void Visit(const char* key, double* value) final {
    PrintSep();
    stream_ << key << "=" << value[0];
  }
  void Visit(const char* key, int64_t* value) final {
    PrintSep();
    stream_ << key << "=" << value[0];
  }
  void Visit(const char* key, uint64_t* value) final {
    PrintSep();
    stream_ << key << "=" << value[0];
  }
  void Visit(const char* key, int* value) final {
    PrintSep();
    stream_ << key << "=" << value[0];
  }
  // bool/string/dtype reuse the parent's literal formatting so they
  // match the rest of the text format (True/False, quoted strings).
  void Visit(const char* key, bool* value) final {
    PrintSep();
    stream_ << key << "=";
    parent_->PrintBool(value[0], stream_);
  }
  void Visit(const char* key, std::string* value) final {
    PrintSep();
    stream_ << key << "=";
    parent_->PrintString(value[0], stream_);
  }
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "do not allow void as argument";
  }
  void Visit(const char* key, DataType* value) final {
    PrintSep();
    stream_ << key << "=";
    parent_->PrintString(runtime::TVMType2String(Type2TVMType(value[0])), stream_);
  }
  // Nested node attributes go through the parent's generic attr printer.
  void Visit(const char* key, NodeRef* value) final {
    PrintSep();
    stream_ << key << "=";
    parent_->PrintAttr(value[0], stream_);
  }
  void Visit(const char* key, runtime::NDArray* value) final {
    LOG(FATAL) << "do not allow NDarray as argument";
  }

 private:
  // A leading ", " is always correct: positional call arguments have
  // already been printed before any attribute.
  void PrintSep() {
    stream_ << ", ";
  }
  std::ostream& stream_; // NOLINT(*)
  TextPrinter* parent_;
};
// Print the attributes of a call expression to os.
// Attributes of the op's own registered attrs type are printed inline as
// key=value pairs (non-default fields only); anything else is emitted as
// a meta-data reference.
void TextPrinter::PrintCallAttrs(const Expr& op,
                                 const Attrs& attrs,
                                 std::ostream& os) { // NOLINT(*)
  if (!attrs.defined()) return;
  if (const auto* op_node = op.as<OpNode>()) {
    if (attrs->type_index() == op_node->attrs_type_index) {
      AttrPrinter printer(os, this);
      // const_cast is needed because VisitNonDefaultAttrs is declared
      // non-const on BaseAttrsNode; printing does not mutate the node.
      const_cast<BaseAttrsNode*>(attrs.operator->())
          ->VisitNonDefaultAttrs(&printer);
      return;
    }
  }
  os << ", " << meta_.GetMetaNode(attrs);
}
// Entry point exposed to the Python frontend: render any node as text.
std::string RelayPrint(const NodeRef& node) {
  return TextPrinter().Print(node);
}
TVM_REGISTER_API("relay._expr._text_print")
.set_body_typed<std::string(const NodeRef&)>(RelayPrint);
} // namespace relay
} // namespace tvm
......@@ -22,6 +22,8 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype);
}
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Array<IndexExpr> shape = args[0];
......@@ -30,8 +32,8 @@ TVM_REGISTER_API("relay._make.TensorType")
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
tvm::IRPrinter *p) {
p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape << ")";
tvm::IRPrinter *p) {
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
......@@ -41,6 +43,8 @@ TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
return TypeParam(n);
}
TVM_REGISTER_NODE_TYPE(TypeParamNode);
TVM_REGISTER_API("relay._make.TypeParam")
.set_body([](TVMArgs args, TVMRetValue *ret) {
int kind = args[1];
......@@ -61,6 +65,8 @@ IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
return IncompleteType(n);
}
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
......@@ -86,6 +92,8 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
return FuncType(n);
}
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
......@@ -111,6 +119,8 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
return TypeRelation(n);
}
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
......@@ -129,6 +139,8 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
return TupleType(n);
}
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TupleTypeNode::make(args[0]);
......
......@@ -11,7 +11,7 @@
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ConvAttrs);
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
bool Conv2DRel(const Array<Type>& types,
int num_inputs,
......@@ -25,7 +25,7 @@ bool Conv2DRel(const Array<Type>& types,
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const ConvAttrs* param = attrs.as<ConvAttrs>();
const Conv2DAttrs* param = attrs.as<Conv2DAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout);
......@@ -113,7 +113,7 @@ Expr MakeConv2D(Expr data,
std::string weight_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<ConvAttrs>();
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
......@@ -148,6 +148,7 @@ with the layer input to produce a tensor of outputs.
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
......@@ -296,6 +297,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DTransposeAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
......
......@@ -78,6 +78,7 @@ RELAY_REGISTER_OP("nn.dense")
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DenseAttrs")
.set_num_inputs(2)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight", "2D Tensor", "Weight matrix.")
......@@ -107,6 +108,7 @@ RELAY_REGISTER_OP("nn.leaky_relu")
`y = x > 0 ? x : alpha * x`
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LeakyReluAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input data.")
.set_support_level(3)
......@@ -135,6 +137,7 @@ RELAY_REGISTER_OP("nn.softmax")
- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
......@@ -163,6 +166,7 @@ RELAY_REGISTER_OP("nn.log_softmax")
- **data**: The input data
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SoftmaxAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
......@@ -171,9 +175,9 @@ RELAY_REGISTER_OP("nn.log_softmax")
// BatchFlatten
bool BatchFlattenRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
......@@ -278,6 +282,7 @@ centered at that value (zero padding is added where necessary).
- **data**: The input tensor.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.LRNAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
......@@ -296,12 +301,12 @@ Expr MakeL2Normalize(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
});
RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer.
.describe(R"code(L2 Normalization layer.
Normalizes along dimension axis using an L2 norm
......@@ -352,6 +357,7 @@ During training, each element of the input is set to zero with probability ``p``
The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input unchanged.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DropoutAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "Input to which dropout will be applied.")
.set_support_level(1)
......@@ -478,6 +484,7 @@ axis to be the last item in the input shape.
.. note::
This operator can be optimized away for inference.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.BatchNormAttrs")
.set_num_inputs(5)
.add_argument("data", "Tensor", "Input to which batch_norm will be applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
......
......@@ -60,7 +60,7 @@ bool PadRel(const Array<Type>& types,
}
// Handler to create a call to the padding op used by front-end FFI
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
auto attrs = make_node<PadAttrs>();
attrs->pad_value = pad_value;
attrs->pad_width = std::move(pad_width);
......
......@@ -76,6 +76,7 @@ RELAY_REGISTER_OP("expand_dims")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ExpandDimsAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("ExpandDims", ExpandDimsRel);
......@@ -481,6 +482,7 @@ RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
......@@ -503,6 +505,7 @@ RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.InitOpAttrs")
.set_num_inputs(0)
.set_support_level(3)
.add_type_rel("InitOp", InitOpRel);
......@@ -697,6 +700,7 @@ RELAY_REGISTER_OP("squeeze")
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.SqueezeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Squeeze", SqueezeRel);
......
......@@ -74,7 +74,10 @@ class CalcDep : private ExprMutator {
}
Expr VisitExpr_(const FunctionNode* f) final {
return FunctionNode::make(f->params, f->ret_type, Eliminate(f->body), f->type_params);
return FunctionNode::make(f->params,
Eliminate(f->body),
f->ret_type,
f->type_params);
}
// generate the let list from dependency graph
......
......@@ -20,6 +20,7 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
......
import tvm
from tvm import relay
from tvm.relay.expr import debug_print
from tvm.relay.ir_builder import IRBuilder
ib = IRBuilder()

def show(e):
    # Smoke check: debug-printing must succeed and return some text.
    r = debug_print(ib.env, e)
    assert r is not None
def test_constant():
    """Printing a scalar Constant must not crash."""
    arr = tvm.nd.array(10)
    const = relay.Constant(arr)
    show(const)
    # should print the array inside?
def test_tuple():
    """Printing an empty Tuple must not crash."""
    fields = tvm.convert([])
    tup = relay.Tuple(fields)
    show(tup)
def test_local_var():
    """Printing a single local Var must not crash."""
    name_hint = 's'
    lv = relay.Var(name_hint)
    show(lv)
def test_dup_var():
    """Two distinct Vars sharing a name hint must print distinctly."""
    lv = relay.Var('s')
    rv = relay.Var('s')
    show(relay.Tuple([lv, rv]))
def test_large_dup_var():
    """Three Vars with the same name hint still print without clashes."""
    av = relay.Var('s')
    bv = relay.Var('s')
    cv = relay.Var('s')
    show(relay.Tuple([av, bv, cv]))
def test_global_var():
    """A GlobalVar keeps its name hint and prints without error."""
    name_hint = 'g'
    gv = relay.GlobalVar(name_hint)
    # The original line was a bare comparison expression (a no-op);
    # assert makes the intended check actually run.
    assert gv.name_hint == name_hint
    show(gv)
def test_function():
    """Printing a Function; relay.Function now takes ret_type AFTER body.

    The constructor order was changed to (params, body, ret_type,
    type_params); the old (params, ret_type, body, ...) order would bind
    the body argument to ret_type.
    """
    param_names = ['a', 'b', 'c', 'd']
    params = tvm.convert([relay.Var(n) for n in param_names])
    ret_type = None
    body = params[0]
    type_params = tvm.convert([])
    fn = relay.Function(params, body, ret_type, type_params)
    show(fn)
def test_call():
    """Printing a Call with several Var arguments must not crash."""
    op = relay.Var('f')
    arg_names = ['a', 'b', 'c', 'd']
    args = tvm.convert([relay.Var(n) for n in arg_names])
    call = relay.Call(op, args, None, None)
    show(call)
def test_let():
    """Printing a Let binding with an annotated Var must not crash."""
    ty = relay.ty.TensorType((10, 20), 'float32')
    lv = relay.Var('x', ty)
    arr = tvm.nd.array(10)
    value = relay.Constant(arr)
    let = relay.Let(lv, value, lv)
    show(let)
def test_if():
    """Printing an If expression must not crash."""
    cond = relay.Var('cond')
    left = relay.Var('left')
    right = relay.Var('right')
    ife = relay.If(cond, left, right)
    show(ife)
def test_tuple_get_item():
    """Printing a TupleGetItem projection must not crash."""
    t = relay.Var('t')
    g = relay.TupleGetItem(t, 0)
    show(g)
import tvm
import numpy as np
from tvm import relay
# Global toggle: printing is enabled only when run as a script (see main).
do_print = [False]

def show(text):
    # Echo the rendered text for manual inspection when enabled.
    if do_print[0]:
        print("---------------------------")
        print(text)
def test_func():
    """astext() on an expression and on a Function must both succeed."""
    x = relay.var("x", shape=(3, 2))
    y = relay.var("y")
    one = relay.const(10e10, dtype="float32")
    z = relay.add(x, one)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    show(z.astext())
    show(f.astext())
def test_env():
    """Environment.astext() prints global defs and per-line type comments."""
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    z = relay.add(x, y)
    z = relay.add(z, z)
    f = relay.Function([x, y], z)
    env = relay.Environment()
    env.add("myf", f)
    text = env.astext()
    # Global functions render as "def @name"; temporaries carry ty comments.
    assert "def @myf" in text
    assert "%1 = add(%0, %0) # ty=float32" in text
    show(text)
def test_meta_data():
    """Large/structured values are hoisted into meta-data references."""
    n, c, h, w = tvm.var("n"), 10, 224, 224
    # NOTE: rebinds w (previously 224) as the weight variable.
    x = relay.var("x", shape=(n, c, h, w))
    w = relay.var("w")
    z = relay.nn.conv2d(x, w,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        channels=2)
    f = relay.Function([x, w], z)
    text = f.astext()
    # Attrs print inline; the symbolic shape var prints as a meta reference.
    assert "channels=2" in text
    assert "meta.Variable(id=0)" in text
    show(text)
    # Constant tensors are also emitted through the meta-data section.
    text = relay.const([1,2,3]).astext()
    assert "meta.relay.Constant(id=0)" in text
    show(text)
def test_call_attrs():
    """Only non-default attribute values appear in the printed call."""
    x = relay.var("x")
    # non default args
    z = relay.nn.softmax(x, axis=2)
    assert "axis=2" in z.astext()
    # default args: nothing extra is printed after the positional args
    z = relay.nn.softmax(x)
    assert "softmax(%x)" in z.astext()
    # non default args
    z = relay.expand_dims(x, axis=2, num_newaxis=2)
    assert "num_newaxis=2" in z.astext()
def test_let_if_scope():
    """Nested let/if bodies each open a braced scope in the text output."""
    x = relay.var("x", "float32")
    y = relay.var("y", "float32")
    cond = relay.var("cond", "bool")
    v1 = relay.var("v")
    v2 = relay.var("v", "float32")
    then_branch = relay.Let(
        v1, relay.const(1, "float32"),
        relay.Let(v2, x, relay.subtract(v1, v2)))
    v3 = relay.var("v")
    let2 = relay.Let(v3, y, v3)
    else_branch = relay.add(let2, let2)
    result = relay.If(cond, then_branch, else_branch)
    f = relay.Function([x, y, cond], result)
    text = f.astext()
    # function body + if + two branches => four opening braces
    assert text.count("{") == 4
    assert "%cond: bool" in text
    show(f.astext())
if __name__ == "__main__":
    # Enable echoing of printer output when run as a standalone script.
    do_print[0] = True
    test_let_if_scope()
    test_func()
    test_env()
    test_meta_data()
    test_call_attrs()
......@@ -10,7 +10,7 @@ def test_well_formed():
let = relay.Let(x, v, x)
assert well_formed(let)
assert not well_formed(relay.Let(x, v, let))
f = relay.Function([x], ty, x)
f = relay.Function([x], x, ty)
assert well_formed(f)
# this test should pass in case of weak uniqueness (only test for shadowing)
# but we want all binder to be distinct from each other.
......
......@@ -262,49 +262,49 @@ def test_function_alpha_equal():
basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
basic_tps = [tp1, tp2]
func = relay.Function([v1, v2],
tt2, v1, basic_tps)
mapped = relay.Function(basic_args, tt2, basic_args[0], basic_tps)
func = relay.Function([v1, v2], v1,
tt2, basic_tps)
mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
assert alpha_equal(func, mapped)
fewer_params = relay.Function([relay.Var("v4", tt2)], tt2, v4, basic_tps)
fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, fewer_params)
more_params = relay.Function([relay.Var("v3", tt1),
relay.Var("v4", tt2),
relay.Var("v2", tt2)], tt2, v4, basic_tps)
relay.Var("v2", tt2)], v4, tt2, basic_tps)
assert not alpha_equal(func, more_params)
params_unordered = relay.Function([v2, v1],
tt2, v1, basic_tps)
params_unordered = relay.Function([v2, v1], v1,
tt2, basic_tps)
assert not alpha_equal(func, params_unordered)
params_mismatch = relay.Function([v1, v3],
tt2, v1, basic_tps)
params_mismatch = relay.Function([v1, v3], v1,
tt2, basic_tps)
assert not alpha_equal(func, params_mismatch)
# also would not typecheck
ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps)
ret_type_mismatch = relay.Function(basic_args, v4, tt1, basic_tps)
assert not alpha_equal(func, ret_type_mismatch)
# also mis-typed
different_body = relay.Function(basic_args, tt2, v3, basic_tps)
different_body = relay.Function(basic_args, v3, tt2, basic_tps)
assert not alpha_equal(func, different_body)
fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1])
fewer_type_params = relay.Function(basic_args, v4, tt2, [tp1])
assert not alpha_equal(func, fewer_type_params)
more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3])
more_type_params = relay.Function(basic_args, v4, tt2, [tp1, tp2, tp3])
assert not alpha_equal(func, more_type_params)
type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1])
type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
assert not alpha_equal(func, type_params_unordered)
different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4])
different_type_params = relay.Function(basic_args, v4, tt2, [tp3, tp4])
assert not alpha_equal(func, different_type_params)
# a well-typed example that also differs in body, ret type, and type params
tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4]))
tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
assert not alpha_equal(func, tupled_example)
......
......@@ -59,7 +59,7 @@ def test_recursion():
n = relay.Var("n", e.int32)
data = relay.Var("data", e.float32)
funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data)))
value = relay.Function([n, data], e.float32, funcbody, [])
value = relay.Function([n, data], funcbody, e.float32, [])
orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)))
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
......
......@@ -13,7 +13,7 @@ def test_free_vars():
let = relay.Let(x, v, x)
fvx = free_vars(let)
assert len(free_vars(let)) == 0
f = relay.Function([x], ty, x)
f = relay.Function([x], x, ty)
assert len(free_vars(f)) == 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment