Unverified Commit ec86d7f1 by Tianqi Chen Committed by GitHub

[REFACTOR] Streamline Function Attr interface. (#5045)

* [REFACTOR] Streamline Function Attr interface.

There have been quite a few recent changes that depend heavily on
the function attr interface. This PR streamlines that interface by introducing
two APIs that cover most of the usages.

- GetAttr, which gets a typed object for a given key
  - HasNonzeroAttr is a helper built on GetAttr for quickly checking a boolean-style attribute
- WithAttr, which creates a new function object with the given attr
  - The API comes with a copy-on-write optimization to avoid unnecessary copies
  - We deliberately pick the prefix With (instead of Set) to indicate that this
    function does not mutate the original input (see the sketch below).
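
For illustration, a minimal sketch of the intended C++ usage, assuming func is a relay::Function and "MyKey" is a made-up attribute key:

    // Read a typed attribute; the second argument is the fallback used when the key is absent.
    Integer val = func->GetAttr<Integer>("MyKey", 0);
    // Quickly check a boolean-style marker attribute.
    bool marked_inline = func->HasNonzeroAttr(attr::kInline);
    // Build a new function that carries an extra attribute; the original object is not mutated.
    func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1));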

On the python side:
- We allow read access via func.attrs (which is a DictAttrs)
- func.with_attr creates a new instance with updated attrs.

We also get rid of the small wrapper functions and make sure the API is centered
around the GetAttr and HasNonzeroAttr interface.

This PR also changes the function construction to follow the new convention.
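
A minimal sketch of the new construction convention (the params/body/ret_type/type_params variables are placeholders, not taken from this diff):

    // Construct through the reference class instead of the removed FunctionNode::make.
    relay::Function func(params, body, ret_type, type_params);
    // Attach attributes afterwards; std::move lets the copy-on-write path reuse the object.
    func = WithAttr(std::move(func), attr::kClosure, tvm::Integer(1));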

* Address review comments

* Address review comments

* Fix doxygen path
parent a9505365
@@ -277,26 +277,13 @@ class BaseAttrsNode : public Object {
  TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
-/*! \brief Base attribute container for all attributes */
+/*!
+ * \brief Managed reference to BaseAttrsNode.
+ * \sa AttrsNode, BaseAttrsNode
+ */
class Attrs : public ObjectRef {
 public:
-  // normal constructor
-  Attrs() {}
-  // construct from shared ptr.
-  explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}
-  /*! \return The attribute node */
-  const BaseAttrsNode* operator->() const {
-    return ptr();
-  }
-  /*! \brief specify container node */
-  using ContainerType = BaseAttrsNode;
- private:
-  /*! \return the internal attribute node */
-  const BaseAttrsNode* ptr() const {
-    return static_cast<const BaseAttrsNode*>(get());
-  }
+  TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};
/*!
@@ -309,12 +296,7 @@ class DictAttrsNode : public BaseAttrsNode {
 public:
  /*! \brief internal attrs map */
  Map<std::string, ObjectRef> dict;
-  /*!
-   * \brief Consruct a Attrs backed by DictAttrsNode.
-   * \param dict The attributes.
-   * \return The dict attributes.
-   */
-  TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);
  // implementations
  void VisitAttrs(AttrVisitor* v) final;
  void VisitNonDefaultAttrs(AttrVisitor* v) final;
@@ -327,6 +309,23 @@ class DictAttrsNode : public BaseAttrsNode {
  TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};
/*!
* \brief Managed reference to DictAttrsNode
* \sa DictAttrsNode.
*/
class DictAttrs : public Attrs {
public:
/*!
 * \brief Construct an Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);
TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
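
For illustration, a hedged sketch of how the new DictAttrs reference replaces the old DictAttrsNode::make helper (the key and value here are made up):

    Map<std::string, ObjectRef> dict = {{"Inline", Integer(1)}};
    DictAttrs attrs(dict);  // previously: Attrs attrs = DictAttrsNode::make(dict);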
// Namespace containing detail implementations
namespace detail {
...
@@ -211,30 +211,6 @@ class GlobalVar : public RelayExpr {
  TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions shares the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};
/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
// PrimExprs that are useful as runtime containers.
//
/*!
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/function.h
* \brief Function nodes.
*/
#ifndef TVM_IR_FUNCTION_H_
#define TVM_IR_FUNCTION_H_
#include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h>
#include <type_traits>
#include <string>
namespace tvm {
/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions share the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
/*! \brief Additional attributes storing the meta-data */
DictAttrs attrs;
/*!
* \brief Get a function attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
 * \tparam TObjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* Integer value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template<typename TObjectRef>
TObjectRef GetAttr(const std::string& attr_key,
TObjectRef default_value = NullValue<TObjectRef>()) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<TObjectRef>((*it).second);
} else {
return default_value;
}
}
/*!
 * \brief Check whether the function has a non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const BaseFunc& f) {
* if (f->HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0)->value != 0;
}
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};
/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
@@ -26,6 +26,7 @@
#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
#include <tvm/ir/adt.h>
#include <string>
...
@@ -26,6 +26,7 @@
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
...
@@ -165,113 +165,6 @@ class Var : public Expr {
};
/*!
* \brief Function (subgraph in computational graph)
*/
class Function;
/*! \brief Function container */
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template paramaters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;
/*!
* \brief The attributes which store metadata about functions.
*/
tvm::Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;
/*!
* \brief Check whether the function is a primitive function.
*
* \return Whether the function is primitive or not.
*/
bool IsPrimitive() const;
/*!
* \brief Check whether the function is marked as inline.
*
* \return Whether the function should be inlined or not.
*/
bool IsMarkedInline() const;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;
TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());
/*!
* \brief Attach the function's parameters to its attributes for use in analysis.
* \return The function with its parameters attached.
*/
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;
/*!
* \brief Retrieve the function's parameters.
*
* \return The function's parameter.
*/
tvm::Map<Var, Constant> GetParams() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
class Function : public BaseFunc {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
};
TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func,
const std::string& key,
const ObjectRef& data);
/*!
 * \brief Call corresponds to operator invocation.
 *  Corresponds to the operator in computational graph terminology.
 */
@@ -550,30 +443,6 @@ class TempExpr : public Expr {
  TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};
/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
} // namespace attr
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_EXPR_H_
@@ -27,16 +27,15 @@
#include <tvm/node/functor.h>
#include <tvm/ir/error.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+#include <tvm/relay/adt.h>
+#include <tvm/relay/op.h>
#include <string>
#include <utility>
#include <unordered_map>
-#include "./expr.h"
-#include "./adt.h"
-#include "./op.h"
namespace tvm {
namespace relay {
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/function.h
* \brief Relay Function.
*/
#ifndef TVM_RELAY_FUNCTION_H_
#define TVM_RELAY_FUNCTION_H_
#include <tvm/ir/function.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Relay Function container
* \sa Function
*/
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
 * This corresponds to template parameters in C++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
/*!
* \brief Managed reference to FunctionNode.
* \sa FunctionNode
*/
class Function : public BaseFunc {
public:
/*!
* \brief Constructor
* \param params The parameters of the function.
* \param body The body of the function.
* \param ret_type The return type of the function.
* \param ty_params The type parameters.
* \param attrs Additional function attributes.
*/
TVM_DLL Function(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::DictAttrs attrs = NullValue<DictAttrs>());
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};
/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
 * \param attr_value The new attribute value.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
 * This is also why we make it a free function instead of a member function,
 * and why we pass func by value as the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value);
/*!
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
 * \brief Indicate the compiler that should be used for building this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
} // namespace attr
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_FUNCTION_H_
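
As a hedged example of how the attribute keys above pair with the new accessors (mirroring checks that appear later in this diff; func stands for some relay::Function):

    // kCompiler holds a string: which external compiler, if any, should build this function.
    auto compiler = func->GetAttr<tir::StringImm>(attr::kCompiler);
    bool use_default = !compiler.defined() || compiler->value == "default";
    // Marker keys such as kPrimitive hold integers and are checked with HasNonzeroAttr.
    if (func->HasNonzeroAttr(attr::kPrimitive)) {
      // treat as a fused primitive function
    }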
@@ -27,6 +27,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op.h>
...
@@ -24,7 +24,7 @@ from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .adt import Constructor, TypeData
from .module import IRModule
-from .attrs import Attrs, make_node
+from .attrs import Attrs, DictAttrs, make_node
from .container import Array, Map
from . import transform
@@ -47,9 +47,7 @@ class Attrs(Object):
        keys : list of str
            List of keys
        """
-        fields = self.list_field_info()
-        for field in fields:
-            yield field.name
+        return [field.name for field in self.list_field_info()]

    def get_int_tuple(self, key):
        """Get a python int tuple of a key
@@ -93,6 +91,39 @@ class Attrs(Object):
    def __getitem__(self, item):
        return self.__getattr__(item)
@tvm._ffi.register_object
class DictAttrs(Attrs):
"""Dictionary attributes.
"""
def _dict(self):
"""Get internal dict"""
return _ffi_api.DictAttrsGetDict(self)
def keys(self):
"""Get list of names in the attribute.
Returns
-------
keys : list of str
List of keys
"""
return [k for k, _ in self.items()]
def __getitem__(self, k):
return self._dict().__getitem__(k)
def __contains__(self, k):
return self._dict().__contains__(k)
def items(self):
"""Get items from the map."""
return self._dict().items()
def __len__(self):
return self._dict().__len__()
def make_node(type_key, **kwargs):
    """Make a new IR node by its type key and fields
...
@@ -53,6 +53,11 @@ class RelayExpr(BaseExpr):
class BaseFunc(RelayExpr):
    """Base class of all functions."""
+    @property
+    def attrs(self):
+        """Return the attrs member of the function.
+        """
+        return _ffi_api.BaseFunc_Attrs(self)

@tvm._ffi.register_object("relay.GlobalVar")
...
@@ -266,22 +266,24 @@ class Function(BaseFunc):
        """
        return Call(self, args, None, None)

-    def get_params(self):
-        return _expr.FunctionGetParams(self)
-
-    def set_params(self, params):
-        for key in params:
-            value = params[key]
-            if isinstance(value, NDArray):
-                params[key] = Constant(value)
-        return _expr.FunctionSetParams(self, params)
-
-    def set_attribute(self, name, ref):
-        return _expr.FunctionSetAttr(self, name, ref)
-
-    def get_attribute(self, name):
-        return _expr.FunctionGetAttr(self, name)
+    def with_attr(self, attr_key, attr_value):
+        """Create a new copy of the function and update the attribute.
+
+        Parameters
+        ----------
+        attr_key : str
+            The attribute key to use.
+
+        attr_value : Object
+            The new attribute value.
+
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        return _expr.FunctionWithAttr(self, attr_key, attr_value)

@register_relay_node
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/ir/adt.cc
+ * \file src/ir/adt.cc
 * \brief ADT type definitions.
 */
#include <tvm/relay/type.h>
...
@@ -53,22 +53,27 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
  return {};
}

-Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
+DictAttrs::DictAttrs(Map<std::string, ObjectRef> dict) {
  ObjectPtr<DictAttrsNode> n = make_object<DictAttrsNode>();
  n->dict = std::move(dict);
-  return Attrs(n);
+  data_ = std::move(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
    auto* op = static_cast<const DictAttrsNode*>(node.get());
    p->stream << op->dict;
  });

TVM_REGISTER_NODE_TYPE(DictAttrsNode);
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);

+TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
+.set_body_typed([](DictAttrs attrs) {
+  return attrs->dict;
+});

using namespace tir;
// Equal handler.
...
@@ -18,11 +18,12 @@
 */
/*!
- * \file src/tvm/ir/expr.cc
+ * \file src/ir/expr.cc
 * \brief The expression AST nodes for the common IR infra.
 */
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
+#include <tvm/ir/function.h>
// NOTE: reverse dependency on top/tir.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/ir/function.cc
* \brief The function data structure.
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/function.h>
namespace tvm {
TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
.set_body_typed([](BaseFunc func) {
return func->attrs;
});
} // namespace tvm
@@ -138,7 +138,7 @@ relay::Function RunTypeCheck(const IRModule& mod,
        << std::endl;
  }
  func =
-      relay::FunctionNode::make(concat(func->params, fv),
+      relay::Function(concat(func->params, fv),
                      func->body,
                      func->ret_type,
                      concat(func->type_params, ftv),
@@ -296,7 +296,7 @@ IRModule IRModule::FromExpr(
  if (auto* func_node = expr.as<relay::FunctionNode>()) {
    func = GetRef<relay::Function>(func_node);
  } else {
-    func = relay::FunctionNode::make(
+    func = relay::Function(
        relay::FreeVars(expr), expr, Type(),
        relay::FreeTypeVars(expr, mod), {});
  }
@@ -363,7 +363,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
      auto func = mod_copy->Lookup(gv->name_hint);
      mod->Add(var, Downcast<relay::Function>(func), update);
    } else {
-      auto func = relay::FunctionNode::make({}, Downcast<RelayExpr>(val), Type(nullptr), {});
+      auto func = relay::Function({}, Downcast<RelayExpr>(val), Type(nullptr), {});
      mod->Add(var, func, update);
    }
    *ret = mod;
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/ir/op.cc
+ * \file src/ir/op.cc
 * \brief Primitive operators and intrinsics.
 */
#include <tvm/ir/op.h>
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/ir/tensor_type.cc
+ * \file src/ir/tensor_type.cc
 * \brief The type system AST nodes of Relay.
 */
#include <tvm/runtime/registry.h>
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/ir/type.cc
+ * \file src/ir/type.cc
 * \brief Common type system AST nodes throughout the IR.
 */
#include <tvm/ir/type.h>
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/ir/type_relation.cc
+ * \file src/ir/type_relation.cc
 * \brief Type relation
 */
#include <tvm/ir/type.h>
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/relay/doc.cc
+ * \file src/relay/doc.cc
 * \brief Doc ADT used for pretty printing.
 *
 * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
...
@@ -30,6 +30,7 @@
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
#include <tvm/runtime/object.h>
#include <memory>
#include <string>
...
@@ -63,7 +63,7 @@ FeatureSet DetectFeature(const Expr& expr) {
  DETECT_DEFAULT_CONSTRUCT(Tuple)
  DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
  DETECT_CONSTRUCT(Function, {
-    if (!op->IsPrimitive()) {
+    if (!op->HasNonzeroAttr(attr::kPrimitive)) {
      ExprVisitor::VisitExpr_(op);
    }
  })
...
@@ -666,7 +666,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
    ErrorReporter *err_reporter = new ErrorReporter();
    auto module = IRModule({}, {});
    auto dummy_fn_name = GlobalVar("test");
-    module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
+    module->Add(dummy_fn_name, Function({}, TupleNode::make({}), Type(), {}, {}));
    auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
    auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
...
@@ -617,15 +617,14 @@ class CompileEngineImpl : public CompileEngineNode {
      auto src_func = it.first->source_func;
      CHECK(src_func.defined());
      if (!src_func->UseDefaultCompiler()) {
-        auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
-        const tvm::tir::StringImmNode* code_gen = compiler.as<tvm::tir::StringImmNode>();
-        CHECK(code_gen) << "No external codegen is set";
+        auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
+        CHECK(code_gen.defined()) << "No external codegen is set";
        if (ext_mods.find(code_gen->value) == ext_mods.end()) {
          ext_mods[code_gen->value] = IRModule({}, {});
        }
-        auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
-        const tvm::tir::StringImmNode* symbol_name = ext_symbol.as<tvm::tir::StringImmNode>();
-        CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
+        auto symbol_name = src_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+        CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
+                                     << AsText(src_func, false);
        auto gv = GlobalVar(symbol_name->value);
        ext_mods[code_gen->value]->Add(gv, src_func);
        cached_ext_funcs.push_back(it.first);
@@ -694,8 +693,9 @@
    if (!key->source_func->UseDefaultCompiler()) {
      auto cache_node = make_object<CachedFuncNode>();
      const auto name_node =
-          FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::tir::StringImmNode>();
-      CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
+          key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+      CHECK(name_node.defined())
+          << "External function has not been attached a name yet.";
      cache_node->func_name = name_node->value;
      cache_node->target = tvm::target::ext_dev();
      value->cached_func = CachedFunc(cache_node);
...
@@ -68,8 +68,8 @@ class CSourceModuleCodegenBase {
   */
  std::string GetExtSymbol(const Function& func) const {
    const auto name_node =
-        FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::tir::StringImmNode>();
-    CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
+        func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
+    CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
    std::string ext_symbol = name_node->value;
    return ext_symbol;
  }
...
@@ -415,7 +415,7 @@ class GraphRuntimeCodegen
    } else {
      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
    }
-    if (!func->IsPrimitive()) {
+    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
      LOG(FATAL) << "TVM only support calls to primitive functions "
                 << "(i.e functions composed of fusable operator invocations)";
    }
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/relay/interpreter.cc
+ * \file src/relay/interpreter.cc
 * \brief An interpreter for the Relay IR.
 */
#include <tvm/runtime/device_api.h>
@@ -516,7 +516,7 @@ class Interpreter :
    }
    if (is_dyn) {
-      CHECK(func->IsPrimitive());
+      CHECK(func->HasNonzeroAttr(attr::kPrimitive));
      out_shapes = ComputeDynamicShape(func, args);
    }
@@ -556,7 +556,7 @@ class Interpreter :
                      const tvm::Array<ObjectRef>& args,
                      const Var& bind = Var()) {
    // Get a reference to the function inside the closure.
-    if (closure->func->IsPrimitive()) {
+    if (closure->func->HasNonzeroAttr(attr::kPrimitive)) {
      return InvokePrimitiveOp(closure->func, args);
    }
    auto func = closure->func;
...
@@ -442,7 +442,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
                          const Expr& outputs) {
    std::vector<Index> argument_registers;
-    CHECK(func->IsPrimitive())
+    CHECK_NE(func->GetAttr<Integer>(attr::kPrimitive, 0)->value, 0)
      << "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
    auto input_tuple = inputs.as<TupleNode>();
@@ -650,7 +650,7 @@
  }
  void VisitExpr_(const FunctionNode* func_node) {
-    if (!func_node->IsPrimitive()) {
+    if (!func_node->HasNonzeroAttr(attr::kPrimitive)) {
      LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
                 << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
                 << "AST: " << GetRef<Function>(func_node);
...
@@ -86,7 +86,7 @@ struct PrimitiveInliner : ExprMutator {
    }
    if (auto func = op.as<FunctionNode>()) {
-      if (func->IsPrimitive()) {
+      if (func->HasNonzeroAttr(attr::kPrimitive)) {
        tvm::Array<Expr> call_args;
        for (auto arg : call->args) {
          auto new_arg = VisitExpr(arg);
@@ -109,7 +109,7 @@
  }
  Expr VisitExpr_(const FunctionNode* func) {
-    if (func->IsPrimitive()) {
+    if (func->HasNonzeroAttr(attr::kPrimitive)) {
      return GetRef<Function>(func);
    } else {
      return ExprMutator::VisitExpr_(func);
@@ -128,7 +128,7 @@
      DLOG(INFO) << "Before inlining primitives: " << global
                 << std::endl << AsText(func, false);
-      func = FunctionNode::make(func->params,
+      func = Function(func->params,
                      VisitExpr(func->body),
                      func->ret_type,
                      func->type_params,
...
@@ -43,13 +43,11 @@ inline std::string GenerateName(const Function& func) {
}

bool IsClosure(const Function& func) {
-  ObjectRef res = FunctionGetAttr(func, attr::kClosure);
-  const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
-  return pval && pval->value != 0;
+  return func->GetAttr<Integer>(attr::kClosure, 0)->value != 0;
}

-Function MarkClosure(const Function& func) {
-  return FunctionSetAttr(func, attr::kClosure, tvm::Integer(1));
+Function MarkClosure(Function func) {
+  return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1));
}

/* The goal of this class is to lift out any nested functions into top-level
@@ -65,7 +63,7 @@ class LambdaLifter : public ExprMutator {
  Expr VisitExpr_(const LetNode* let_node) final {
    bool is_lambda = false;
    if (auto func = let_node->value.as<FunctionNode>()) {
-      if (!func->IsPrimitive()) {
+      if (!func->HasNonzeroAttr(attr::kPrimitive)) {
        is_lambda = true;
        letrec_.push_back(let_node->var);
      }
@@ -96,7 +94,7 @@
    auto func = GetRef<Function>(func_node);
    // We should not transform primitive functions.
-    if (func->IsPrimitive()) {
+    if (func->HasNonzeroAttr(attr::kPrimitive)) {
      return std::move(func);
    }
@@ -151,10 +149,10 @@
    // code for the closure.
    Function lifted_func;
    if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
-      lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
+      lifted_func = Function(body->params, body->body, body->ret_type, body->type_params);
    } else {
      lifted_func =
-          FunctionNode::make(captured_vars, body, func->func_type_annotation(), free_type_vars);
+          Function(captured_vars, body, func->func_type_annotation(), free_type_vars);
      lifted_func = MarkClosure(lifted_func);
    }
@@ -191,7 +189,7 @@
      if (auto* n = pair.second.as<FunctionNode>()) {
        if (!n->UseDefaultCompiler()) continue;
        auto func = GetRef<Function>(n);
-        func = FunctionNode::make(func->params,
+        func = Function(func->params,
                      VisitExpr(func->body),
                      func->ret_type,
                      func->type_params,
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/ir/adt.cc
+ * \file src/ir/adt.cc
 * \brief AST nodes for Relay algebraic data types (ADTs).
 */
#include <tvm/relay/type.h>
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/relay/ir/expr.cc
+ * \file src/relay/ir/expr.cc
 * \brief The expression AST nodes of Relay.
 */
#include <tvm/ir/module.h>
@@ -110,118 +110,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    p->stream << ")";
  });
Function FunctionNode::make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Attrs attrs) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->attrs = std::move(attrs);
return Function(n);
}
FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
Type param_type = (param->type_annotation.defined()) ? param->type_annotation
: IncompleteType(Kind::kType);
param_types.push_back(param_type);
}
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteType(Kind::kType);
return FuncType(param_types, ret_type, this->type_params, {});
}
bool FunctionNode::IsPrimitive() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kPrimitive);
const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
return pval && pval->value != 0;
}
bool FunctionNode::IsMarkedInline() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kInline);
const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
return pval && pval->value != 0;
}
Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
return FunctionSetAttr(GetRef<Function>(this), attr::kParams, parameters);
}
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams")
.set_body_typed(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters);
});
tvm::Map<Var, Constant> FunctionNode::GetParams() const {
auto node_ref = FunctionGetAttr(GetRef<Function>(this), attr::kParams);
return Downcast<tvm::Map<Var, Constant>>(node_ref);
}
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams")
.set_body_typed([](const Function& func) {
return func->GetParams();
});
bool FunctionNode::UseDefaultCompiler() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kCompiler);
const tir::StringImmNode* pval = res.as<tir::StringImmNode>();
return pval == nullptr || pval->value == "default";
}
ObjectRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return ObjectRef(); }
const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
CHECK(dict_attrs);
auto it = dict_attrs->dict.find(key);
if (it != dict_attrs->dict.end()) {
return (*it).second;
} else {
return ObjectRef();
}
}
Function FunctionSetAttr(const Function& func, const std::string& key, const ObjectRef& data) {
const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
Attrs func_attrs;
if (dattrs) {
Map<std::string, ObjectRef> dict = dattrs->dict;
dict.Set(key, data);
func_attrs = DictAttrsNode::make(dict);
} else {
Map<std::string, ObjectRef> dict = {{key, data}};
func_attrs = DictAttrsNode::make(dict);
}
return FunctionNode::make(
func->params,
func->body,
func->ret_type,
func->type_params,
func_attrs);
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
.set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", "
<< node->attrs << ")";
});
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
                    Array<Type> type_args) {
@@ -360,18 +248,6 @@ TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize")
    return temp->Realize();
  });
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return FunctionSetAttr(func, name, ref);
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr")
.set_body_typed(
[](Function func, std::string name) {
return FunctionGetAttr(func, name);
});
TVM_REGISTER_GLOBAL("relay._make.Any") TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); }); .set_body_typed([]() { return Any::make(); });
......
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/relay/expr_functor.cc
+ * \file src/relay/expr_functor.cc
 * \brief A wrapper around ExprFunctor which functionally updates the AST.
 *
 * ExprMutator uses memoization and self return in order to amortize
@@ -109,7 +109,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
      body.same_as(op->body)) {
    return GetRef<Expr>(op);
  } else {
-    return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
+    return Function(params, body, ret_type, ty_params, op->attrs);
  }
}
@@ -417,7 +417,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
        new_params.size() == func->params.size()) {
      return expr;
    }
-    auto ret = FunctionNode::make(new_params,
+    auto ret = Function(new_params,
                      new_body,
                      func->ret_type,
                      func->type_params,
@@ -431,7 +431,7 @@
        new_params.push_back(v);
      }
    }
-    ret = FunctionNode::make(new_params,
+    ret = Function(new_params,
                      new_body,
                      func->ret_type,
                      func->type_params,
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/ir/function.cc
* \brief Function in relay.
*/
#include <tvm/relay/function.h>
namespace tvm {
namespace relay {
Function::Function(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> type_params,
DictAttrs attrs) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->attrs = std::move(attrs);
data_ = std::move(n);
}
FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
Type param_type = (param->type_annotation.defined()) ? param->type_annotation
: IncompleteType(Kind::kType);
param_types.push_back(param_type);
}
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteType(Kind::kType);
return FuncType(param_types, ret_type, this->type_params, {});
}
bool FunctionNode::UseDefaultCompiler() const {
tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
return !val.defined() || val->value == "default";
}
Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value) {
FunctionNode* node = func.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
.set_body_typed([](tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::DictAttrs attrs) {
return Function(params, body, ret_type, ty_params, attrs);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", "
<< node->attrs << ")";
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});
} // namespace relay
} // namespace tvm
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/relay/ir/hash.cc
+ * \file src/relay/ir/hash.cc
 * \brief Hash functions for Relay types and expressions.
 */
#include <tvm/ir/type_functor.h>
...
@@ -18,7 +18,7 @@
 */
/*!
- * \file src/tvm/relay/ir/op_strategy.cc
+ * \file src/relay/ir/op_strategy.cc
 * \brief The Relay operator Strategy and related data structure.
 */
...
@@ -139,9 +139,8 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
}

bool FunctionPassNode::SkipFunction(const Function& func) const {
-  ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
-  const tir::IntImmNode* pval = skip_opt.as<tir::IntImmNode>();
-  return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
+  return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
+         !(func->UseDefaultCompiler());
}

Pass CreateFunctionPass(
...
@@ -99,7 +99,7 @@ Pass QuantizeAnnotate() {
      for (const auto& x : FreeVars(func)) {
        new_params.push_back(x);
      }
-      return FunctionNode::make(new_params,
+      return Function(new_params,
                      func->body,
                      func->ret_type,
                      func->type_params,
...
@@ -151,7 +151,7 @@ class StatsCollector : private ExprMutator {
    const FunctionNode* func = new_e.as<FunctionNode>();
    CHECK(func) << "Input shoule be Function";
    Expr new_body = TupleNode::make(std::move(profile_data_));
-    return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
+    return Function(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
                    func->attrs);
  }
...
@@ -78,7 +78,7 @@ Expr DeDup(const Expr& e) {
    for (const Var& param : op->params) {
      params.push_back(Fresh(param));
    }
-    return FunctionNode::make(params,
+    return Function(params,
                    VisitExpr(op->body),
                    VisitType(op->ret_type),
                    type_params,
...
@@ -521,13 +521,13 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
    }
    CHECK_GT(new_body.size(), 0U);
    if (new_body.size() == 1) {
-      return FunctionNode::make(params, new_body[0], Type(nullptr),
+      return Function(params, new_body[0], Type(nullptr),
                      fn->type_params, fn->attrs);
    } else if (tuple->fields.size() == new_body.size()) {
      return new_expr;
    } else {
      Tuple tuple_body = TupleNode::make(new_body);
-      return FunctionNode::make(params, tuple_body, Type(nullptr),
+      return Function(params, tuple_body, Type(nullptr),
                      fn->type_params, fn->attrs);
    }
  } else {
...
@@ -111,7 +111,7 @@ class EtaExpander : public ExprMutator {
    Expr body = CallNode::make(cons, params, Attrs());
    Type ret_type = TypeCall(cons->belong_to, type_params);
-    return FunctionNode::make(
+    return Function(
      Downcast<tvm::Array<Var>>(params),
      body,
      ret_type,
@@ -135,7 +135,7 @@
      args.push_back(var);
    }
-    return FunctionNode::make(
+    return Function(
      args,
      CallNode::make(gvar, params),
      func->ret_type,
...
@@ -209,7 +209,7 @@ class ConstantFolder : public ExprMutator {
      func = Downcast<Function>(expr);
    } else {
      // TODO(@jroesch): fix this
-      func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
+      func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
    }
    auto mod = IRModule(
      {},
...
...@@ -852,7 +852,7 @@ class FuseMutator : private ExprMutator { ...@@ -852,7 +852,7 @@ class FuseMutator : private ExprMutator {
// Skip primitive function. // Skip primitive function.
Expr VisitExpr_(const FunctionNode* fn_node) { Expr VisitExpr_(const FunctionNode* fn_node) {
if (fn_node->IsPrimitive()) { if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
return GetRef<Expr>(fn_node); return GetRef<Expr>(fn_node);
} else { } else {
return ExprMutator::VisitExpr_(fn_node); return ExprMutator::VisitExpr_(fn_node);
...@@ -932,8 +932,8 @@ class FuseMutator : private ExprMutator { ...@@ -932,8 +932,8 @@ class FuseMutator : private ExprMutator {
} visitor; } visitor;
visitor(body); visitor(body);
const GroupInfo& ginfo = ginfo_[group]; const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); auto func = Function(ginfo.params, body, ret_type, {});
func = FunctionSetAttr(func, attr::kPrimitive, tvm::Integer(visitor.has_call)); func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
return CallNode::make(func, ginfo.arguments, Attrs()); return CallNode::make(func, ginfo.arguments, Attrs());
} }
......
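In the FuseMutator hunk the fused function is tagged with `WithAttr(std::move(func), attr::kPrimitive, ...)` and later queried with `HasNonzeroAttr(attr::kPrimitive)`. The Python counterpart used by the updated tests is `with_attr`, which returns a new function rather than mutating the receiver; a small sketch under made-up names:

```python
import tvm
from tvm import relay

p = relay.var("p", shape=(10, 10))
q = relay.var("q", shape=(10, 10))
f = relay.Function([p, q], relay.add(p, q))

# Functional update: f itself is left as-is, prim carries the new attribute.
prim = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
assert prim.attrs["Primitive"].value == 1
```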
...@@ -255,7 +255,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) { ...@@ -255,7 +255,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
return Pair(res.forward, grad); return Pair(res.forward, grad);
}); });
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {}); return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
} }
TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient") TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
...@@ -384,7 +384,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { ...@@ -384,7 +384,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
} }
Expr BPEmpty() { Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleType::Empty(), {}); Expr unitF = Function({}, TupleNode::make({}), TupleType::Empty(), {});
return RefCreateNode::make(unitF); return RefCreateNode::make(unitF);
} }
...@@ -413,7 +413,7 @@ struct ReverseAD : ExprMutator { ...@@ -413,7 +413,7 @@ struct ReverseAD : ExprMutator {
auto x_var = ll->Push(x); auto x_var = ll->Push(x);
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll)); auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp)); auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make( Expr nbp = Function(
{}, {},
LetList::With([&](LetList* ll) { LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var // we need a new ReverseAD visitor to avoid clobbering the bp local var
...@@ -457,7 +457,7 @@ struct ReverseAD : ExprMutator { ...@@ -457,7 +457,7 @@ struct ReverseAD : ExprMutator {
orig_var->checked_type_ = call->checked_type(); orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll)); auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp)); auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make( Expr nbp = Function(
{}, {},
LetList::With([&](LetList* ll) { LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll)); tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
...@@ -583,7 +583,7 @@ Expr Gradient(const Expr& re, const IRModule& mod) { ...@@ -583,7 +583,7 @@ Expr Gradient(const Expr& re, const IRModule& mod) {
}; };
return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret)); return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
}); });
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {}); return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
} }
TVM_REGISTER_GLOBAL("relay._transform.gradient") TVM_REGISTER_GLOBAL("relay._transform.gradient")
......
...@@ -83,7 +83,7 @@ class Inliner : ExprMutator { ...@@ -83,7 +83,7 @@ class Inliner : ExprMutator {
} }
Function Inline(const Function& func) { Function Inline(const Function& func) {
return FunctionNode::make(func->params, return Function(func->params,
VisitExpr(func->body), VisitExpr(func->body),
func->ret_type, func->ret_type,
func->type_params, func->type_params,
...@@ -101,7 +101,7 @@ class Inliner : ExprMutator { ...@@ -101,7 +101,7 @@ class Inliner : ExprMutator {
if (!func->body.defined()) return false; if (!func->body.defined()) return false;
// The function must be annotated with the inline attribute. // The function must be annotated with the inline attribute.
if (!func->IsMarkedInline()) return false; if (!func->HasNonzeroAttr(attr::kInline)) return false;
// The function is not able to be inlined if any callee under the CallGraph // The function is not able to be inlined if any callee under the CallGraph
// of this function cannot be inlined. // of this function cannot be inlined.
...@@ -124,7 +124,7 @@ class Inliner : ExprMutator { ...@@ -124,7 +124,7 @@ class Inliner : ExprMutator {
const auto* fn = base_func.as<FunctionNode>(); const auto* fn = base_func.as<FunctionNode>();
CHECK(fn) << "Expected to work on a Relay function."; CHECK(fn) << "Expected to work on a Relay function.";
auto func = FunctionNode::make(fn->params, auto func = Function(fn->params,
fn->body, fn->body,
fn->ret_type, fn->ret_type,
fn->type_params, fn->type_params,
...@@ -198,7 +198,7 @@ IRModule Inline(const IRModule& module) { ...@@ -198,7 +198,7 @@ IRModule Inline(const IRModule& module) {
auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar()); auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar());
if (const auto* fn = base_func.as<FunctionNode>()) { if (const auto* fn = base_func.as<FunctionNode>()) {
auto func = GetRef<Function>(fn); auto func = GetRef<Function>(fn);
if (func->IsMarkedInline()) { if (func->HasNonzeroAttr(attr::kInline)) {
CHECK_EQ(cgn->GetRefCount(), 0U) CHECK_EQ(cgn->GetRefCount(), 0U)
<< cgn->GetNameHint() << " is marked as inline but not inlined."; << cgn->GetNameHint() << " is marked as inline but not inlined.";
cgn->CleanCallGraphEntries(); cgn->CleanCallGraphEntries();
......
...@@ -140,9 +140,10 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -140,9 +140,10 @@ class MergeCompositeWrapper : public ExprMutator {
if (call->op->IsInstance<FunctionNode>()) { if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op); Function func = Downcast<Function>(call->op);
CHECK(func.defined()); CHECK(func.defined());
const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>(); const auto name_node =
func->GetAttr<tir::StringImm>(attr::kComposite);
// don't step into existing composite functions // don't step into existing composite functions
if (name_node && name_node->value != "") { if (name_node.defined() && name_node->value != "") {
tvm::Array<tvm::relay::Expr> new_args; tvm::Array<tvm::relay::Expr> new_args;
for (const auto& arg : call->args) { for (const auto& arg : call->args) {
auto new_e = this->Mutate(arg); auto new_e = this->Mutate(arg);
...@@ -166,8 +167,8 @@ class MergeCompositeWrapper : public ExprMutator { ...@@ -166,8 +167,8 @@ class MergeCompositeWrapper : public ExprMutator {
if (extract.defined()) { if (extract.defined()) {
auto free_vars = FreeVars(extract); auto free_vars = FreeVars(extract);
// make the composite function // make the composite function
auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs()); auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_)); f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
// find the expressions associated with the free vars using the args_map // find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function // this tells us which expressions should be given as inputs to the composite function
Array<Expr> args; Array<Expr> args;
......
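MergeComposite now reads the composite name through `func->GetAttr<tir::StringImm>(attr::kComposite)` and checks `defined()` on the result instead of calling `FunctionGetAttr`. From Python the same information is reachable through `func.attrs`, which behaves like a dictionary; an illustrative sketch:

```python
import tvm
from tvm import relay

i1 = relay.var("i1", shape=(10, 10))
i2 = relay.var("i2", shape=(10, 10))
add_relu = relay.Function([i1, i2], relay.nn.relu(relay.add(i1, i2)))
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("add_relu"))

# func.attrs is a DictAttrs: membership tests and indexing both work.
assert "Composite" in add_relu.attrs
assert add_relu.attrs["Composite"] == "add_relu"
```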
...@@ -820,7 +820,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -820,7 +820,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
Func VisitFuncStatic(const Function& func, const Expr& var) { Func VisitFuncStatic(const Function& func, const Expr& var) {
CHECK(IsAtomic(var)); CHECK(IsAtomic(var));
if (func->IsPrimitive()) { if (func->HasNonzeroAttr(attr::kPrimitive)) {
return ConstEvaluateFunc(func); return ConstEvaluateFunc(func);
} }
std::vector<std::pair<Var, PStatic> > free_vars; std::vector<std::pair<Var, PStatic> > free_vars;
...@@ -881,7 +881,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -881,7 +881,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
return store_.Extend<Expr>([&]() { return store_.Extend<Expr>([&]() {
store_.Invalidate(); store_.Invalidate();
return FunctionNode::make(func->params, return Function(func->params,
LetList::With([&](LetList* ll) { LetList::With([&](LetList* ll) {
std::vector<PStatic> pv; std::vector<PStatic> pv;
for (const auto& v : func->params) { for (const auto& v : func->params) {
......
...@@ -211,15 +211,18 @@ class Partitioner : public ExprMutator { ...@@ -211,15 +211,18 @@ class Partitioner : public ExprMutator {
} }
auto subgraph_func = auto subgraph_func =
FunctionNode::make(params, input, call->checked_type_, {}, Attrs()); Function(params, input, call->checked_type_, {});
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func = subgraph_func =
FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name)); WithAttr(std::move(subgraph_func), attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); subgraph_func =
subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler, WithAttr(std::move(subgraph_func), attr::kPrimitive, tvm::Integer(1));
tvm::tir::StringImmNode::make(compiler_attrs->compiler)); subgraph_func =
subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1)); WithAttr(std::move(subgraph_func), attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
CHECK(!module_->ContainGlobalVar(name)) CHECK(!module_->ContainGlobalVar(name))
<< "Global function " << name << " already exists"; << "Global function " << name << " already exists";
// Create a global function and add it to the IRModule for the subgraph. // Create a global function and add it to the IRModule for the subgraph.
...@@ -277,7 +280,7 @@ class Partitioner : public ExprMutator { ...@@ -277,7 +280,7 @@ class Partitioner : public ExprMutator {
params.push_back(new_param); params.push_back(new_param);
} }
auto body = VisitExpr(op->body); auto body = VisitExpr(op->body);
return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs); return Function(params, body, op->ret_type, op->type_params, op->attrs);
} }
} }
...@@ -351,7 +354,7 @@ class Partitioner : public ExprMutator { ...@@ -351,7 +354,7 @@ class Partitioner : public ExprMutator {
for (const auto& pair : glob_funcs) { for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) { if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn); auto func = GetRef<Function>(fn);
func = FunctionNode::make(func->params, func = Function(func->params,
VisitExpr(func->body), VisitExpr(func->body),
func->ret_type, func->ret_type,
func->type_params, func->type_params,
......
...@@ -91,7 +91,7 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map); ...@@ -91,7 +91,7 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
*/ */
inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) { inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) {
if (const FunctionNode* f = e.as<FunctionNode>()) { if (const FunctionNode* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params, func(f->body), f->ret_type, f->type_params, f->attrs); return Function(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
} else { } else {
return func(e); return func(e);
} }
......
...@@ -208,10 +208,10 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -208,10 +208,10 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr_(const FunctionNode* f, const Var& v) final { Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
Expr e = GetRef<Expr>(f); Expr e = GetRef<Expr>(f);
Expr ret; Expr ret;
if (f->IsPrimitive()) { if (f->HasNonzeroAttr(attr::kPrimitive)) {
ret = e; ret = e;
} else { } else {
ret = FunctionNode::make(f->params, ret = Function(f->params,
GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
f->ret_type, f->ret_type,
f->type_params, f->type_params,
......
...@@ -142,7 +142,7 @@ Function ToCPS(const Function& f, ...@@ -142,7 +142,7 @@ Function ToCPS(const Function& f,
} }
Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { Expr VisitExpr_(const FunctionNode* op, const MCont& k) final {
CHECK(!op->IsPrimitive()) << "primitive func not supported yet."; CHECK(!op->HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet.";
return k(ToCPS(GetRef<Function>(op), m, cm, vm, answer)); return k(ToCPS(GetRef<Function>(op), m, cm, vm, answer));
} }
...@@ -182,7 +182,7 @@ Function ToCPS(const Function& f, ...@@ -182,7 +182,7 @@ Function ToCPS(const Function& f,
Expr reify(const MCont& k) { Expr reify(const MCont& k) {
Var arg = VarNode::make("arg", Type()); Var arg = VarNode::make("arg", Type());
return FunctionNode::make({arg}, k(arg), Type(), {}, {}); return Function({arg}, k(arg), Type(), {}, {});
} }
Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) { Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) {
...@@ -293,7 +293,7 @@ Function ToCPS(const Function& f, ...@@ -293,7 +293,7 @@ Function ToCPS(const Function& f,
new_params.push_back(remap(v)); new_params.push_back(remap(v));
} }
new_params.push_back(k); new_params.push_back(k);
return FunctionNode::make(new_params, return Function(new_params,
mut.VisitExpr(f->body, mut.VisitExpr(f->body,
[&](const Expr& e) { return CallNode::make(k, {e}); }), [&](const Expr& e) { return CallNode::make(k, {e}); }),
answer, answer,
...@@ -328,7 +328,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { ...@@ -328,7 +328,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
Function ret = ToCPS(f, m, cm, &var, answer); Function ret = ToCPS(f, m, cm, &var, answer);
auto new_type_params = ret->type_params; auto new_type_params = ret->type_params;
new_type_params.push_back(answer); new_type_params.push_back(answer);
return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); return Function(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs);
} }
Function ToCPS(const Function& f, const IRModule& m) { Function ToCPS(const Function& f, const IRModule& m) {
...@@ -355,7 +355,7 @@ Function UnCPS(const Function& f) { ...@@ -355,7 +355,7 @@ Function UnCPS(const Function& f) {
// TODO(@M.K.): make alphaequal work on free term // TODO(@M.K.): make alphaequal work on free term
// CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type))); // CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
auto x = VarNode::make("x", new_ret_type); auto x = VarNode::make("x", new_ret_type);
auto cont = FunctionNode::make({x}, x, new_ret_type, {}, {}); auto cont = Function({x}, x, new_ret_type, {}, {});
tvm::Array<Expr> args; tvm::Array<Expr> args;
for (const auto& p : new_params) { for (const auto& p : new_params) {
args.push_back(p); args.push_back(p);
...@@ -366,7 +366,7 @@ Function UnCPS(const Function& f) { ...@@ -366,7 +366,7 @@ Function UnCPS(const Function& f) {
type_args.push_back(tp); type_args.push_back(tp);
} }
type_args.push_back(new_ret_type); type_args.push_back(new_ret_type);
return FunctionNode::make(new_params, return Function(new_params,
CallNode::make(f, args, {}, type_args), CallNode::make(f, args, {}, type_args),
new_ret_type, new_ret_type,
new_type_params, new_type_params,
......
...@@ -82,7 +82,7 @@ TEST(Relay, BuildModule) { ...@@ -82,7 +82,7 @@ TEST(Relay, BuildModule) {
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {}); auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type); auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {}); auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {}); auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
......
...@@ -28,11 +28,11 @@ TEST(Relay, SelfReference) { ...@@ -28,11 +28,11 @@ TEST(Relay, SelfReference) {
using namespace tvm; using namespace tvm;
auto tensor_type = relay::TensorType({}, DataType::Bool()); auto tensor_type = relay::TensorType({}, DataType::Bool());
auto x = relay::VarNode::make("x", relay::Type()); auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {}); auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>()); CHECK(f->IsInstance<BaseFuncNode>());
auto y = relay::VarNode::make("y", tensor_type); auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y }); auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {}); auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto mod = IRModule::FromExpr(fx); auto mod = IRModule::FromExpr(fx);
mod = relay::transform::InferType()(mod); mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main"); auto type_fx = mod->Lookup("main");
......
...@@ -53,7 +53,7 @@ TEST(Relay, Sequential) { ...@@ -53,7 +53,7 @@ TEST(Relay, Sequential) {
// Let expression and variable a should be dead-code eliminated. // Let expression and variable a should be dead-code eliminated.
auto z3 = relay::LetNode::make(a, c, z2); auto z3 = relay::LetNode::make(a, c, z2);
relay::Function func = relay::Function func =
relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {}); relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});
// Get schedule // Get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register"); auto reg = tvm::runtime::Registry::Get("relay.op._Register");
...@@ -96,7 +96,7 @@ TEST(Relay, Sequential) { ...@@ -96,7 +96,7 @@ TEST(Relay, Sequential) {
auto zz = relay::CallNode::make(add_op, {y1, c1}); auto zz = relay::CallNode::make(add_op, {y1, c1});
zz = relay::CallNode::make(add_op, {zz, zz}); zz = relay::CallNode::make(add_op, {zz, zz});
relay::Function expected_func = relay::Function expected_func =
relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {}); relay::Function(relay::FreeVars(zz), zz, relay::Type(), {});
// Infer type for the expected function. // Infer type for the expected function.
auto mod1 = IRModule::FromExpr(expected_func); auto mod1 = IRModule::FromExpr(expected_func);
......
...@@ -58,7 +58,7 @@ TEST(MicroStandaloneRuntime, BuildModule) { ...@@ -58,7 +58,7 @@ TEST(MicroStandaloneRuntime, BuildModule) {
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {}); auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type); auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {}); auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {}); auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0}); auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
......
...@@ -134,7 +134,7 @@ def test_recursive_func(): ...@@ -134,7 +134,7 @@ def test_recursive_func():
func = relay.Function([i], func = relay.Function([i],
sb.get(), sb.get(),
ret_type=relay.TensorType([], 'int32')) ret_type=relay.TensorType([], 'int32'))
func = func.set_attribute("Compiler", tvm.tir.StringImm("a")) func = func.with_attr("Compiler", tvm.tir.StringImm("a"))
mod[sum_up] = func mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32') iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg)) mod["main"] = relay.Function([iarg], sum_up(iarg))
......
...@@ -78,9 +78,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ...@@ -78,9 +78,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol): def set_external_func_attr(func, compiler, ext_symbol):
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm(compiler)) func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm(ext_symbol)) func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
return func return func
......
...@@ -307,7 +307,7 @@ def get_synthetic_lib(): ...@@ -307,7 +307,7 @@ def get_synthetic_lib():
gcc_input3 = relay.var('gcc_input3', shape=(10, 10)) gcc_input3 = relay.var('gcc_input3', shape=(10, 10))
subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2, subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2,
gcc_input3], relay.copy(gcc_input0)) gcc_input3], relay.copy(gcc_input0))
subgraph0 = subgraph0.set_attribute( subgraph0 = subgraph0.with_attr(
"Primitive", tvm.tir.IntImm("int32", 1)) "Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph0 # Call subgraph0
...@@ -320,7 +320,7 @@ def get_synthetic_lib(): ...@@ -320,7 +320,7 @@ def get_synthetic_lib():
gcc_input7 = relay.var('gcc_input7', shape=(10, 10)) gcc_input7 = relay.var('gcc_input7', shape=(10, 10))
subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6, subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6,
gcc_input7], relay.copy(gcc_input4)) gcc_input7], relay.copy(gcc_input4))
subgraph1 = subgraph1.set_attribute( subgraph1 = subgraph1.with_attr(
"Primitive", tvm.tir.IntImm("int32", 1)) "Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph1 # Call subgraph1
......
...@@ -169,15 +169,16 @@ def test_function(): ...@@ -169,15 +169,16 @@ def test_function():
body = relay.Tuple(tvm.runtime.convert([])) body = relay.Tuple(tvm.runtime.convert([]))
type_params = tvm.runtime.convert([]) type_params = tvm.runtime.convert([])
fn = relay.Function(params, body, ret_type, type_params) fn = relay.Function(params, body, ret_type, type_params)
fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value")) fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value"))
assert fn.params == params assert fn.params == params
assert fn.body == body assert fn.body == body
assert fn.type_params == type_params assert fn.type_params == type_params
assert fn.span == None assert fn.span == None
assert fn.get_attribute("test_attribute") == "value" assert fn.attrs["test_attribute"] == "value"
str(fn) str(fn)
check_json_roundtrip(fn) check_json_roundtrip(fn)
@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.") @pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
def test_function_attrs(): def test_function_attrs():
param_names = ['a', 'b', 'c', 'd'] param_names = ['a', 'b', 'c', 'd']
...@@ -190,8 +191,10 @@ def test_function_attrs(): ...@@ -190,8 +191,10 @@ def test_function_attrs():
for param in params[:1]: for param in params[:1]:
cty = param.type_annotation cty = param.type_annotation
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype) tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
model_params[param] = tvm.nd.array(tensor) model_params[param] = relay.Constant(tvm.nd.array(tensor))
fn = fn.set_params(model_params)
fn = fn.with_attr("__params__", model_params)
assert fn.params == params assert fn.params == params
assert fn.body == body assert fn.body == body
assert fn.type_params == type_params assert fn.type_params == type_params
...@@ -200,7 +203,7 @@ def test_function_attrs(): ...@@ -200,7 +203,7 @@ def test_function_attrs():
check_json_roundtrip(fn) check_json_roundtrip(fn)
json_str = tvm.ir.save_json(fn) json_str = tvm.ir.save_json(fn)
fn_after = tvm.ir.load_json(json_str) fn_after = tvm.ir.load_json(json_str)
model_params_after = fn_after.get_params() model_params_after = fn_after.attrs["__params__"]
after_keys = [item[0] for item in model_params_after.items()] after_keys = [item[0] for item in model_params_after.items()]
for key1, key2 in zip(model_params, after_keys): for key1, key2 in zip(model_params, after_keys):
assert key1.name_hint == key2.name_hint assert key1.name_hint == key2.name_hint
...@@ -296,4 +299,3 @@ if __name__ == "__main__": ...@@ -296,4 +299,3 @@ if __name__ == "__main__":
test_tuple_get_item() test_tuple_get_item()
test_op() test_op()
test_conv2d_attrs() test_conv2d_attrs()
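In `test_function_attrs` above, the old `set_params`/`get_params` wrappers are replaced by storing the parameter map under the `"__params__"` attribute and reading it back from `func.attrs`. A compact sketch of that pattern with made-up shapes:

```python
import numpy as np
import tvm
from tvm import relay

w = relay.var("w", shape=(2, 2))
fn = relay.Function([w], w)

params = {w: relay.Constant(tvm.nd.array(np.ones((2, 2), dtype="float32")))}
fn = fn.with_attr("__params__", params)  # stored as a Map-valued attribute

bound = fn.attrs["__params__"]
keys = [item[0] for item in bound.items()]
assert keys[0].name_hint == "w"
```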
...@@ -353,7 +353,7 @@ def test_function_attr(): ...@@ -353,7 +353,7 @@ def test_function_attr():
p00 = relay.subtract(z00, w01) p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02) q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00) func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.set_attribute("FuncName", tvm.tir.StringImm("a")) func0 = func0.with_attr("FuncName", tvm.tir.StringImm("a"))
x1 = relay.var('x1', shape=(10, 10)) x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10)) w10 = relay.var('w10', shape=(10, 10))
...@@ -363,7 +363,7 @@ def test_function_attr(): ...@@ -363,7 +363,7 @@ def test_function_attr():
p10 = relay.subtract(z10, w11) p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12) q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10) func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.set_attribute("FuncName", tvm.tir.StringImm("b")) func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b"))
assert not alpha_equal(func0, func1) assert not alpha_equal(func0, func1)
...@@ -694,7 +694,7 @@ def test_fn_attribute(): ...@@ -694,7 +694,7 @@ def test_fn_attribute():
d = relay.var('d', shape=(10, 10)) d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d) add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1) add_1_fn = relay.Function([c, d], add_1)
add_1_fn = add_1_fn.set_attribute("TestAttribute", tvm.tir.StringImm("test")) add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType()) add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not relay.analysis.alpha_equal(add_1_fn, add_fn) assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
......
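The `test_function_attr` and `test_fn_attribute` hunks above rely on attributes being part of what `alpha_equal` compares: two functions that differ only in an attribute value are reported unequal. Condensed into a standalone sketch:

```python
import tvm
from tvm import relay

def make_fn(tag):
    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    fn = relay.Function([a, b], relay.add(a, b))
    return fn.with_attr("FuncName", tvm.tir.StringImm(tag))

assert not relay.analysis.alpha_equal(make_fn("a"), make_fn("b"))
```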
...@@ -164,7 +164,7 @@ def test_simple_merge(): ...@@ -164,7 +164,7 @@ def test_simple_merge():
add_node = relay.add(in_1, in_2) add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node) relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node) add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
# merged function # merged function
r = relay.Call(add_relu, [a, b]) r = relay.Call(add_relu, [a, b])
...@@ -229,7 +229,7 @@ def test_branch_merge(): ...@@ -229,7 +229,7 @@ def test_branch_merge():
sub_node = relay.subtract(in_1, in_2) sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node) mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node) add_sub_mul = relay.Function([in_1, in_2], mul_node)
add_sub_mul = add_sub_mul.set_attribute("Composite", add_sub_mul = add_sub_mul.with_attr("Composite",
tir.StringImm("add_sub_mul")) tir.StringImm("add_sub_mul"))
# add_sub_mul1 function # add_sub_mul1 function
...@@ -239,7 +239,7 @@ def test_branch_merge(): ...@@ -239,7 +239,7 @@ def test_branch_merge():
sub_node_1 = relay.subtract(in_3, in_4) sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1) mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1) add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite", add_sub_mul_1 = add_sub_mul_1.with_attr("Composite",
tir.StringImm("add_sub_mul")) tir.StringImm("add_sub_mul"))
# merged function # merged function
...@@ -299,7 +299,7 @@ def test_reuse_call_merge(): ...@@ -299,7 +299,7 @@ def test_reuse_call_merge():
add_node_1 = relay.add(in_1, add_node) add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node) add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2) add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.set_attribute("Composite", add_add_add = add_add_add.with_attr("Composite",
tir.StringImm("add_add_add")) tir.StringImm("add_add_add"))
# merged function # merged function
...@@ -383,7 +383,7 @@ def test_multiple_patterns(): ...@@ -383,7 +383,7 @@ def test_multiple_patterns():
bias_node = relay.nn.bias_add(conv_node, in_3) bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node) r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r) conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite", conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite",
tir.StringImm("conv2d_bias_relu")) tir.StringImm("conv2d_bias_relu"))
# add_relu function # add_relu function
...@@ -392,7 +392,7 @@ def test_multiple_patterns(): ...@@ -392,7 +392,7 @@ def test_multiple_patterns():
add_node = relay.add(in_4, in_5) add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node) r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r) add_relu = relay.Function([in_4, in_5], r)
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu")) add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
# merged function # merged function
conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias]) conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
...@@ -461,7 +461,7 @@ def test_merge_order(): ...@@ -461,7 +461,7 @@ def test_merge_order():
out = relay.abs(out) out = relay.abs(out)
out = relay.nn.relu(out) out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out) merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Composite', merged_func = merged_func.with_attr('Composite',
tir.StringImm(composite_name)) tir.StringImm(composite_name))
ret = relay.Call(merged_func, [input_1, input_2]) ret = relay.Call(merged_func, [input_1, input_2])
return relay.Function([input_1, input_2], ret) return relay.Function([input_1, input_2], ret)
...@@ -527,13 +527,13 @@ def test_parallel_merge(): ...@@ -527,13 +527,13 @@ def test_parallel_merge():
y = relay.var('y') y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y)) branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1) func_1 = relay.Function([x, y], branch_1)
func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul")) func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul"))
call_1 = relay.Call(func_1, [input_1, input_2]) call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1') x1 = relay.var('x1')
y1 = relay.var('y1') y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1)) branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2) func_2 = relay.Function([x1, y1], branch_2)
func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul")) func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul"))
call_2 = relay.Call(func_2, [input_1, input_2]) call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2) out = relay.multiply(call_1, call_2)
return relay.Function([input_1, input_2], out) return relay.Function([input_1, input_2], out)
...@@ -612,14 +612,14 @@ def test_multiple_input_subgraphs(): ...@@ -612,14 +612,14 @@ def test_multiple_input_subgraphs():
add_relu_1 = relay.add(x, y) add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1) add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1) add_relu_1 = relay.Function([x, y], add_relu_1)
add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]]) add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1') x1 = relay.var('x1')
y1 = relay.var('y1') y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1) add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2) add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2) add_relu_2 = relay.Function([x1, y1], add_relu_2)
add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu')) add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]]) add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2') x2 = relay.var('x2')
y2 = relay.var('y2') y2 = relay.var('y2')
...@@ -627,7 +627,7 @@ def test_multiple_input_subgraphs(): ...@@ -627,7 +627,7 @@ def test_multiple_input_subgraphs():
sub = relay.subtract(x2, y2) sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub) add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul) add_sub_mul = relay.Function([x2, y2], add_sub_mul)
add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul')) add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2]) add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call) return relay.Function(inputs, add_sub_mul_call)
...@@ -640,7 +640,7 @@ def test_multiple_input_subgraphs(): ...@@ -640,7 +640,7 @@ def test_multiple_input_subgraphs():
add_relu = relay.add(x, y) add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu) add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu) add_relu = relay.Function([x, y], add_relu)
add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu')) add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]]) add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call) add_relu_calls.append(add_relu_call)
......
...@@ -303,11 +303,11 @@ def test_extern_ccompiler_default_ops(): ...@@ -303,11 +303,11 @@ def test_extern_ccompiler_default_ops():
add = x0 + y0 add = x0 + y0
# Function that uses C compiler # Function that uses C compiler
func = relay.Function([x0, y0], add) func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", func = func.with_attr("Compiler",
tvm.tir.StringImm("ccompiler")) tvm.tir.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol", func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0")) tvm.tir.StringImm("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0") glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func mod[glb_0] = func
...@@ -318,7 +318,7 @@ def test_extern_ccompiler_default_ops(): ...@@ -318,7 +318,7 @@ def test_extern_ccompiler_default_ops():
exp = relay.exp(p0) exp = relay.exp(p0)
concat = relay.concatenate([log, exp], axis=0) concat = relay.concatenate([log, exp], axis=0)
fused_func = relay.Function([p0], concat) fused_func = relay.Function([p0], concat)
fused_func = fused_func.set_attribute("Primitive", fused_func = fused_func.with_attr("Primitive",
tvm.tir.IntImm("int32", 1)) tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call]) fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call) main = relay.Function([x, y], fused_call)
...@@ -390,10 +390,10 @@ def test_extern_dnnl(): ...@@ -390,10 +390,10 @@ def test_extern_dnnl():
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
func = relay.Function([data0, input0, input1], out) func = relay.Function([data0, input0, input1], out)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl")) func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
func = func.set_attribute("ExternalSymbol", func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("dnnl_0")) tvm.tir.StringImm("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0") glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule() mod = tvm.IRModule()
...@@ -516,11 +516,11 @@ def test_function_lifting(): ...@@ -516,11 +516,11 @@ def test_function_lifting():
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple()) bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler", func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler")) tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol", func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0")) tvm.tir.StringImm("test_compiler_0"))
gv0 = relay.GlobalVar("test_compiler_0") gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0 mod[gv0] = func0
...@@ -535,11 +535,11 @@ def test_function_lifting(): ...@@ -535,11 +535,11 @@ def test_function_lifting():
channels=16, channels=16,
padding=(1, 1)) padding=(1, 1))
func1 = relay.Function([data1, weight1], conv) func1 = relay.Function([data1, weight1], conv)
func1 = func1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Compiler", func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_compiler")) tvm.tir.StringImm("test_compiler"))
func1 = func1.set_attribute("ExternalSymbol", func1 = func1.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_1")) tvm.tir.StringImm("test_compiler_1"))
gv1 = relay.GlobalVar("test_compiler_1") gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1 mod[gv1] = func1
...@@ -609,11 +609,11 @@ def test_function_lifting_inline(): ...@@ -609,11 +609,11 @@ def test_function_lifting_inline():
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar) bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar], func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple()) bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler", func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler")) tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol", func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0")) tvm.tir.StringImm("test_compiler_0"))
# main function # main function
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_make_attrs():
try:
x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("unknown_key") != -1
try:
x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("upper bound") != -1
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
def test_dict_attrs():
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
assert dattr.name.value == "xyz"
assert isinstance(dattr, tvm.ir.DictAttrs)
assert "name" in dattr
assert dattr["x"].value == 1
assert len(dattr) == 4
assert len([x for x in dattr.keys()]) == 4
assert len(dattr.items()) == 4
if __name__ == "__main__":
test_make_attrs()
test_dict_attrs()
...@@ -54,33 +54,6 @@ def test_make_node(): ...@@ -54,33 +54,6 @@ def test_make_node():
assert AA.value_index == A.value_index assert AA.value_index == A.value_index
def test_make_attrs():
try:
x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("unknown_key") != -1
try:
x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("upper bound") != -1
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
assert dattr.name.value == "xyz"
def test_make_sum(): def test_make_sum():
A = te.placeholder((2, 10), name='A') A = te.placeholder((2, 10), name='A')
k = te.reduce_axis((0,10), "k") k = te.reduce_axis((0,10), "k")
......