Unverified Commit ec86d7f1 by Tianqi Chen Committed by GitHub

[REFACTOR] Streamline Function Attr interface. (#5045)

* [REFACTOR] Streamline Function Attr interface.

There has been quite a few recent changes that depends heavily on
the function attr interface. This PR streamlines that interface by introducing
two APIs that covers most of the usages.

- GetAttr which gets a typed object for a given key
  - HasNonzeroAttr is a quick helper that calls GetAttr to quickly check an attribute
- WithAttr that creates a new function object with the given attr
  - The API comes with copy on write optimization to avoid multiple copies
  - We deliberately pick the prefix With(instead of Set) to indicate this
    function does not mutate the original input.

On the python side:
- We allow read access via func.attrs (which is a DictAttr)
- func.with_attrs to create a new instance with updated attrs.

We also get rid of the small wrapper functions and make sure the API centered around
the GetAttr and HasNonzeroAttr interface.

This PR also changes the function construction to follow the new convention.

* Address review comments

* Address review comments

* Fix doxygen path
parent a9505365
......@@ -277,26 +277,13 @@ class BaseAttrsNode : public Object {
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
/*! \brief Base attribute container for all attributes */
/*!
* \brief Managed reference to BaseAttrsNode.
* \sa AttrsNode, BaseAttrsNode
*/
class Attrs : public ObjectRef {
public:
// normal constructor
Attrs() {}
// construct from shared ptr.
explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}
/*! \return The attribute node */
const BaseAttrsNode* operator->() const {
return ptr();
}
/*! \brief specify container node */
using ContainerType = BaseAttrsNode;
private:
/*! \return the internal attribute node */
const BaseAttrsNode* ptr() const {
return static_cast<const BaseAttrsNode*>(get());
}
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};
/*!
......@@ -309,12 +296,7 @@ class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);
// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
......@@ -327,6 +309,23 @@ class DictAttrsNode : public BaseAttrsNode {
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};
/*!
* \brief Managed reference to DictAttrsNode
* \sa DictAttrsNode.
*/
class DictAttrs : public Attrs {
public:
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);
TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
// Namespace containing detail implementations
namespace detail {
......
......@@ -211,30 +211,6 @@ class GlobalVar : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions shares the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};
/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
// PrimExprs that are useful as runtime containers.
//
/*!
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/function.h
* \brief Function nodes.
*/
#ifndef TVM_IR_FUNCTION_H_
#define TVM_IR_FUNCTION_H_
#include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h>
#include <type_traits>
#include <string>
namespace tvm {
/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions share the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
/*! \brief Additional attributes storing the meta-data */
DictAttrs attrs;
/*!
* \brief Get a function attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* Integer value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template<typename TObjectRef>
TObjectRef GetAttr(const std::string& attr_key,
TObjectRef default_value = NullValue<TObjectRef>()) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<TObjectRef>((*it).second);
} else {
return default_value;
}
}
/*!
* \brief Check whether the function has an non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const BaseFunc& f) {
* if (f->HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0)->value != 0;
}
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};
/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
......@@ -26,6 +26,7 @@
#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/adt.h>
#include <string>
......
......@@ -26,6 +26,7 @@
#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
......
......@@ -165,113 +165,6 @@ class Var : public Expr {
};
/*!
* \brief Function (subgraph in computational graph)
*/
class Function;
/*! \brief Function container */
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template paramaters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;
/*!
* \brief The attributes which store metadata about functions.
*/
tvm::Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;
/*!
* \brief Check whether the function is a primitive function.
*
* \return Whether the function is primitive or not.
*/
bool IsPrimitive() const;
/*!
* \brief Check whether the function is marked as inline.
*
* \return Whether the function should be inlined or not.
*/
bool IsMarkedInline() const;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;
TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());
/*!
* \brief Attach the function's parameters to its attributes for use in analysis.
* \return The function with its parameters attached.
*/
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;
/*!
* \brief Retrieve the function's parameters.
*
* \return The function's parameter.
*/
tvm::Map<Var, Constant> GetParams() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
class Function : public BaseFunc {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
};
TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func,
const std::string& key,
const ObjectRef& data);
/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
*/
......@@ -550,30 +443,6 @@ class TempExpr : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};
/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
} // namespace attr
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
......@@ -27,16 +27,15 @@
#include <tvm/node/functor.h>
#include <tvm/ir/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>
#include <string>
#include <utility>
#include <unordered_map>
#include "./expr.h"
#include "./adt.h"
#include "./op.h"
namespace tvm {
namespace relay {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/function.h
* \brief Relay Function.
*/
#ifndef TVM_RELAY_FUNCTION_H_
#define TVM_RELAY_FUNCTION_H_
#include <tvm/ir/function.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Relay Function container
* \sa Function
*/
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template paramaters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;
/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
/*!
* \brief Managed reference to FunctionNode.
* \sa FunctionNode
*/
class Function : public BaseFunc {
public:
/*!
* \brief Constructor
* \param params The parameters of the function.
* \param body The body of the function.
* \param ret_type The return type of the function.
* \param ty_params The type parameters.
* \param attrs Additional function attributes.
*/
TVM_DLL Function(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::DictAttrs attrs = NullValue<DictAttrs>());
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};
/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
TVM_DLL Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value);
/*!
* \brief namespace of the attributes that can be attached to a relay::Function.
*/
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
} // namespace attr
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_FUNCTION_H_
......@@ -27,6 +27,7 @@
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op.h>
......
......@@ -24,7 +24,7 @@ from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, BaseFunc, Range
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs, make_node
from .attrs import Attrs, DictAttrs, make_node
from .container import Array, Map
from . import transform
......@@ -47,9 +47,7 @@ class Attrs(Object):
keys : list of str
List of keys
"""
fields = self.list_field_info()
for field in fields:
yield field.name
return [field.name for field in self.list_field_info()]
def get_int_tuple(self, key):
"""Get a python int tuple of a key
......@@ -93,6 +91,39 @@ class Attrs(Object):
def __getitem__(self, item):
return self.__getattr__(item)
@tvm._ffi.register_object
class DictAttrs(Attrs):
"""Dictionary attributes.
"""
def _dict(self):
"""Get internal dict"""
return _ffi_api.DictAttrsGetDict(self)
def keys(self):
"""Get list of names in the attribute.
Returns
-------
keys : list of str
List of keys
"""
return [k for k, _ in self.items()]
def __getitem__(self, k):
return self._dict().__getitem__(k)
def __contains__(self, k):
return self._dict().__contains__(k)
def items(self):
"""Get items from the map."""
return self._dict().items()
def __len__(self):
return self._dict().__len__()
def make_node(type_key, **kwargs):
"""Make a new IR node by its type key and fields
......
......@@ -53,6 +53,11 @@ class RelayExpr(BaseExpr):
class BaseFunc(RelayExpr):
"""Base class of all functions."""
@property
def attrs(self):
"""Return the attrs member of the function.
"""
return _ffi_api.BaseFunc_Attrs(self)
@tvm._ffi.register_object("relay.GlobalVar")
......
......@@ -266,22 +266,24 @@ class Function(BaseFunc):
"""
return Call(self, args, None, None)
def get_params(self):
return _expr.FunctionGetParams(self)
def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
def set_params(self, params):
for key in params:
value = params[key]
if isinstance(value, NDArray):
params[key] = Constant(value)
Parameters
----------
attr_key : str
The attribute key to use.
return _expr.FunctionSetParams(self, params)
attr_value : Object
The new attribute value.
def set_attribute(self, name, ref):
return _expr.FunctionSetAttr(self, name, ref)
Returns
-------
func : Function
A new copy of the function
"""
return _expr.FunctionWithAttr(self, attr_key, attr_value)
def get_attribute(self, name):
return _expr.FunctionGetAttr(self, name)
@register_relay_node
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/ir/adt.cc
* \file src/ir/adt.cc
* \brief ADT type definitions.
*/
#include <tvm/relay/type.h>
......
......@@ -53,22 +53,27 @@ Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
return {};
}
Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
DictAttrs::DictAttrs(Map<std::string, ObjectRef> dict) {
ObjectPtr<DictAttrsNode> n = make_object<DictAttrsNode>();
n->dict = std::move(dict);
return Attrs(n);
data_ = std::move(n);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
auto* op = static_cast<const DictAttrsNode*>(node.get());
p->stream << op->dict;
});
TVM_REGISTER_NODE_TYPE(DictAttrsNode);
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);
TVM_REGISTER_GLOBAL("ir.DictAttrsGetDict")
.set_body_typed([](DictAttrs attrs) {
return attrs->dict;
});
using namespace tir;
// Equal handler.
......
......@@ -18,11 +18,12 @@
*/
/*!
* \file src/tvm/ir/expr.cc
* \file src/ir/expr.cc
* \brief The expression AST nodes for the common IR infra.
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
// NOTE: reverse dependency on top/tir.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/ir/function.cc
* \brief The function data structure.
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/function.h>
namespace tvm {
TVM_REGISTER_GLOBAL("ir.BaseFunc_Attrs")
.set_body_typed([](BaseFunc func) {
return func->attrs;
});
} // namespace tvm
......@@ -138,7 +138,7 @@ relay::Function RunTypeCheck(const IRModule& mod,
<< std::endl;
}
func =
relay::FunctionNode::make(concat(func->params, fv),
relay::Function(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
......@@ -296,7 +296,7 @@ IRModule IRModule::FromExpr(
if (auto* func_node = expr.as<relay::FunctionNode>()) {
func = GetRef<relay::Function>(func_node);
} else {
func = relay::FunctionNode::make(
func = relay::Function(
relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
......@@ -363,7 +363,7 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
auto func = mod_copy->Lookup(gv->name_hint);
mod->Add(var, Downcast<relay::Function>(func), update);
} else {
auto func = relay::FunctionNode::make({}, Downcast<RelayExpr>(val), Type(nullptr), {});
auto func = relay::Function({}, Downcast<RelayExpr>(val), Type(nullptr), {});
mod->Add(var, func, update);
}
*ret = mod;
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/ir/op.cc
* \file src/ir/op.cc
* \brief Primitive operators and intrinsics.
*/
#include <tvm/ir/op.h>
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/ir/tensor_type.cc
* \file src/ir/tensor_type.cc
* \brief The type system AST nodes of Relay.
*/
#include <tvm/runtime/registry.h>
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/ir/type.cc
* \file src/ir/type.cc
* \brief Common type system AST nodes throughout the IR.
*/
#include <tvm/ir/type.h>
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/ir/type_relation.cc
* \file src/ir/type_relation.cc
* \brief Type relation
*/
#include <tvm/ir/type.h>
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/relay/doc.cc
* \file src/relay/doc.cc
* \brief Doc ADT used for pretty printing.
*
* Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98
......
......@@ -30,6 +30,7 @@
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/runtime/object.h>
#include <memory>
#include <string>
......
......@@ -63,7 +63,7 @@ FeatureSet DetectFeature(const Expr& expr) {
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
DETECT_CONSTRUCT(Function, {
if (!op->IsPrimitive()) {
if (!op->HasNonzeroAttr(attr::kPrimitive)) {
ExprVisitor::VisitExpr_(op);
}
})
......
......@@ -666,7 +666,7 @@ TVM_REGISTER_GLOBAL("relay._analysis._test_type_solver")
ErrorReporter *err_reporter = new ErrorReporter();
auto module = IRModule({}, {});
auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, FunctionNode::make({}, TupleNode::make({}), Type(), {}, {}));
module->Add(dummy_fn_name, Function({}, TupleNode::make({}), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
......
......@@ -617,15 +617,14 @@ class CompileEngineImpl : public CompileEngineNode {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (!src_func->UseDefaultCompiler()) {
auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
const tvm::tir::StringImmNode* code_gen = compiler.as<tvm::tir::StringImmNode>();
CHECK(code_gen) << "No external codegen is set";
auto code_gen = src_func->GetAttr<tir::StringImm>(attr::kCompiler);
CHECK(code_gen.defined()) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = IRModule({}, {});
}
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::tir::StringImmNode* symbol_name = ext_symbol.as<tvm::tir::StringImmNode>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto symbol_name = src_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
......@@ -694,8 +693,9 @@ class CompileEngineImpl : public CompileEngineNode {
if (!key->source_func->UseDefaultCompiler()) {
auto cache_node = make_object<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::tir::StringImmNode>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
key->source_func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
CHECK(name_node.defined())
<< "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
......
......@@ -68,8 +68,8 @@ class CSourceModuleCodegenBase {
*/
std::string GetExtSymbol(const Function& func) const {
const auto name_node =
FunctionGetAttr(func, attr::kExternalSymbol).as<tvm::tir::StringImmNode>();
CHECK(name_node != nullptr) << "Fail to retrieve external symbol.";
func->GetAttr<tir::StringImm>(attr::kExternalSymbol);
CHECK(name_node.defined()) << "Fail to retrieve external symbol.";
std::string ext_symbol = name_node->value;
return ext_symbol;
}
......
......@@ -415,7 +415,7 @@ class GraphRuntimeCodegen
} else {
LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
}
if (!func->IsPrimitive()) {
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
LOG(FATAL) << "TVM only support calls to primitive functions "
<< "(i.e functions composed of fusable operator invocations)";
}
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/relay/interpreter.cc
* \file src/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
#include <tvm/runtime/device_api.h>
......@@ -516,7 +516,7 @@ class Interpreter :
}
if (is_dyn) {
CHECK(func->IsPrimitive());
CHECK(func->HasNonzeroAttr(attr::kPrimitive));
out_shapes = ComputeDynamicShape(func, args);
}
......@@ -556,7 +556,7 @@ class Interpreter :
const tvm::Array<ObjectRef>& args,
const Var& bind = Var()) {
// Get a reference to the function inside the closure.
if (closure->func->IsPrimitive()) {
if (closure->func->HasNonzeroAttr(attr::kPrimitive)) {
return InvokePrimitiveOp(closure->func, args);
}
auto func = closure->func;
......
......@@ -442,7 +442,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
const Expr& outputs) {
std::vector<Index> argument_registers;
CHECK(func->IsPrimitive())
CHECK_NE(func->GetAttr<Integer>(attr::kPrimitive, 0)->value, 0)
<< "internal error: invoke_tvm_op requires the first argument to be a relay::Function";
auto input_tuple = inputs.as<TupleNode>();
......@@ -650,7 +650,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
void VisitExpr_(const FunctionNode* func_node) {
if (!func_node->IsPrimitive()) {
if (!func_node->HasNonzeroAttr(attr::kPrimitive)) {
LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl
<< "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
<< "AST: " << GetRef<Function>(func_node);
......
......@@ -86,7 +86,7 @@ struct PrimitiveInliner : ExprMutator {
}
if (auto func = op.as<FunctionNode>()) {
if (func->IsPrimitive()) {
if (func->HasNonzeroAttr(attr::kPrimitive)) {
tvm::Array<Expr> call_args;
for (auto arg : call->args) {
auto new_arg = VisitExpr(arg);
......@@ -109,7 +109,7 @@ struct PrimitiveInliner : ExprMutator {
}
Expr VisitExpr_(const FunctionNode* func) {
if (func->IsPrimitive()) {
if (func->HasNonzeroAttr(attr::kPrimitive)) {
return GetRef<Function>(func);
} else {
return ExprMutator::VisitExpr_(func);
......@@ -128,7 +128,7 @@ struct PrimitiveInliner : ExprMutator {
DLOG(INFO) << "Before inlining primitives: " << global
<< std::endl << AsText(func, false);
func = FunctionNode::make(func->params,
func = Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
......
......@@ -43,13 +43,11 @@ inline std::string GenerateName(const Function& func) {
}
bool IsClosure(const Function& func) {
ObjectRef res = FunctionGetAttr(func, attr::kClosure);
const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
return pval && pval->value != 0;
return func->GetAttr<Integer>(attr::kClosure, 0)->value != 0;
}
Function MarkClosure(const Function& func) {
return FunctionSetAttr(func, attr::kClosure, tvm::Integer(1));
Function MarkClosure(Function func) {
return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1));
}
/* The goal of this class is to lift out any nested functions into top-level
......@@ -65,7 +63,7 @@ class LambdaLifter : public ExprMutator {
Expr VisitExpr_(const LetNode* let_node) final {
bool is_lambda = false;
if (auto func = let_node->value.as<FunctionNode>()) {
if (!func->IsPrimitive()) {
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
is_lambda = true;
letrec_.push_back(let_node->var);
}
......@@ -96,7 +94,7 @@ class LambdaLifter : public ExprMutator {
auto func = GetRef<Function>(func_node);
// We should not transform primitive functions.
if (func->IsPrimitive()) {
if (func->HasNonzeroAttr(attr::kPrimitive)) {
return std::move(func);
}
......@@ -151,10 +149,10 @@ class LambdaLifter : public ExprMutator {
// code for the closure.
Function lifted_func;
if (captured_vars.size() == 0 && free_type_vars.size() == 0) {
lifted_func = FunctionNode::make(body->params, body->body, body->ret_type, body->type_params);
lifted_func = Function(body->params, body->body, body->ret_type, body->type_params);
} else {
lifted_func =
FunctionNode::make(captured_vars, body, func->func_type_annotation(), free_type_vars);
Function(captured_vars, body, func->func_type_annotation(), free_type_vars);
lifted_func = MarkClosure(lifted_func);
}
......@@ -191,7 +189,7 @@ class LambdaLifter : public ExprMutator {
if (auto* n = pair.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params,
func = Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/ir/adt.cc
* \file src/ir/adt.cc
* \brief AST nodes for Relay algebraic data types (ADTs).
*/
#include <tvm/relay/type.h>
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/relay/ir/expr.cc
* \file src/relay/ir/expr.cc
* \brief The expression AST nodes of Relay.
*/
#include <tvm/ir/module.h>
......@@ -110,118 +110,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});
Function FunctionNode::make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> type_params,
tvm::Attrs attrs) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->attrs = std::move(attrs);
return Function(n);
}
FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
Type param_type = (param->type_annotation.defined()) ? param->type_annotation
: IncompleteType(Kind::kType);
param_types.push_back(param_type);
}
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteType(Kind::kType);
return FuncType(param_types, ret_type, this->type_params, {});
}
bool FunctionNode::IsPrimitive() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kPrimitive);
const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
return pval && pval->value != 0;
}
bool FunctionNode::IsMarkedInline() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kInline);
const tir::IntImmNode* pval = res.as<tir::IntImmNode>();
return pval && pval->value != 0;
}
Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
return FunctionSetAttr(GetRef<Function>(this), attr::kParams, parameters);
}
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams")
.set_body_typed(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters);
});
tvm::Map<Var, Constant> FunctionNode::GetParams() const {
auto node_ref = FunctionGetAttr(GetRef<Function>(this), attr::kParams);
return Downcast<tvm::Map<Var, Constant>>(node_ref);
}
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams")
.set_body_typed([](const Function& func) {
return func->GetParams();
});
bool FunctionNode::UseDefaultCompiler() const {
ObjectRef res = FunctionGetAttr(GetRef<Function>(this), attr::kCompiler);
const tir::StringImmNode* pval = res.as<tir::StringImmNode>();
return pval == nullptr || pval->value == "default";
}
ObjectRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return ObjectRef(); }
const DictAttrsNode* dict_attrs = func->attrs.as<DictAttrsNode>();
CHECK(dict_attrs);
auto it = dict_attrs->dict.find(key);
if (it != dict_attrs->dict.end()) {
return (*it).second;
} else {
return ObjectRef();
}
}
Function FunctionSetAttr(const Function& func, const std::string& key, const ObjectRef& data) {
const DictAttrsNode* dattrs = func->attrs.as<DictAttrsNode>();
Attrs func_attrs;
if (dattrs) {
Map<std::string, ObjectRef> dict = dattrs->dict;
dict.Set(key, data);
func_attrs = DictAttrsNode::make(dict);
} else {
Map<std::string, ObjectRef> dict = {{key, data}};
func_attrs = DictAttrsNode::make(dict);
}
return FunctionNode::make(
func->params,
func->body,
func->ret_type,
func->type_params,
func_attrs);
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
.set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", "
<< node->attrs << ")";
});
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
Array<Type> type_args) {
......@@ -360,18 +248,6 @@ TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize")
return temp->Realize();
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return FunctionSetAttr(func, name, ref);
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr")
.set_body_typed(
[](Function func, std::string name) {
return FunctionGetAttr(func, name);
});
TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); });
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/relay/expr_functor.cc
* \file src/relay/expr_functor.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
......@@ -109,7 +109,7 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return FunctionNode::make(params, body, ret_type, ty_params, op->attrs);
return Function(params, body, ret_type, ty_params, op->attrs);
}
}
......@@ -417,7 +417,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
new_params.size() == func->params.size()) {
return expr;
}
auto ret = FunctionNode::make(new_params,
auto ret = Function(new_params,
new_body,
func->ret_type,
func->type_params,
......@@ -431,7 +431,7 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
new_params.push_back(v);
}
}
ret = FunctionNode::make(new_params,
ret = Function(new_params,
new_body,
func->ret_type,
func->type_params,
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/ir/function.cc
* \brief Function in relay.
*/
#include <tvm/relay/function.h>
namespace tvm {
namespace relay {
Function::Function(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> type_params,
DictAttrs attrs) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->type_params = std::move(type_params);
n->attrs = std::move(attrs);
data_ = std::move(n);
}
FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
Type param_type = (param->type_annotation.defined()) ? param->type_annotation
: IncompleteType(Kind::kType);
param_types.push_back(param_type);
}
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteType(Kind::kType);
return FuncType(param_types, ret_type, this->type_params, {});
}
bool FunctionNode::UseDefaultCompiler() const {
tir::StringImm val = this->GetAttr<tir::StringImm>(attr::kCompiler);
return !val.defined() || val->value == "default";
}
Function WithAttr(Function func, const std::string& attr_key, ObjectRef attr_value) {
FunctionNode* node = func.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
}
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_GLOBAL("relay._make.Function")
.set_body_typed([](tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::DictAttrs attrs) {
return Function(params, body, ret_type, ty_params, attrs);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());
p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
<< ", " << node->body << ", " << node->type_params << ", "
<< node->attrs << ")";
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionWithAttr")
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return WithAttr(std::move(func), name, ref);
});
} // namespace relay
} // namespace tvm
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/relay/ir/hash.cc
* \file src/relay/ir/hash.cc
* \brief Hash functions for Relay types and expressions.
*/
#include <tvm/ir/type_functor.h>
......
......@@ -18,7 +18,7 @@
*/
/*!
* \file src/tvm/relay/ir/op_strategy.cc
* \file src/relay/ir/op_strategy.cc
* \brief The Relay operator Strategy and related data structure.
*/
......
......@@ -139,9 +139,8 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
}
bool FunctionPassNode::SkipFunction(const Function& func) const {
ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization);
const tir::IntImmNode* pval = skip_opt.as<tir::IntImmNode>();
return (pval && pval->value != 0) || (!func->UseDefaultCompiler());
return func->GetAttr<Integer>(attr::kSkipOptimization, 0)->value != 0 ||
!(func->UseDefaultCompiler());
}
Pass CreateFunctionPass(
......
......@@ -99,7 +99,7 @@ Pass QuantizeAnnotate() {
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
}
return FunctionNode::make(new_params,
return Function(new_params,
func->body,
func->ret_type,
func->type_params,
......
......@@ -151,7 +151,7 @@ class StatsCollector : private ExprMutator {
const FunctionNode* func = new_e.as<FunctionNode>();
CHECK(func) << "Input shoule be Function";
Expr new_body = TupleNode::make(std::move(profile_data_));
return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
return Function(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
func->attrs);
}
......
......@@ -78,7 +78,7 @@ Expr DeDup(const Expr& e) {
for (const Var& param : op->params) {
params.push_back(Fresh(param));
}
return FunctionNode::make(params,
return Function(params,
VisitExpr(op->body),
VisitType(op->ret_type),
type_params,
......
......@@ -521,13 +521,13 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
}
CHECK_GT(new_body.size(), 0U);
if (new_body.size() == 1) {
return FunctionNode::make(params, new_body[0], Type(nullptr),
return Function(params, new_body[0], Type(nullptr),
fn->type_params, fn->attrs);
} else if (tuple->fields.size() == new_body.size()) {
return new_expr;
} else {
Tuple tuple_body = TupleNode::make(new_body);
return FunctionNode::make(params, tuple_body, Type(nullptr),
return Function(params, tuple_body, Type(nullptr),
fn->type_params, fn->attrs);
}
} else {
......
......@@ -111,7 +111,7 @@ class EtaExpander : public ExprMutator {
Expr body = CallNode::make(cons, params, Attrs());
Type ret_type = TypeCall(cons->belong_to, type_params);
return FunctionNode::make(
return Function(
Downcast<tvm::Array<Var>>(params),
body,
ret_type,
......@@ -135,7 +135,7 @@ class EtaExpander : public ExprMutator {
args.push_back(var);
}
return FunctionNode::make(
return Function(
args,
CallNode::make(gvar, params),
func->ret_type,
......
......@@ -209,7 +209,7 @@ class ConstantFolder : public ExprMutator {
func = Downcast<Function>(expr);
} else {
// TODO(@jroesch): fix this
func = FunctionNode::make(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
}
auto mod = IRModule(
{},
......
......@@ -852,7 +852,7 @@ class FuseMutator : private ExprMutator {
// Skip primitive function.
Expr VisitExpr_(const FunctionNode* fn_node) {
if (fn_node->IsPrimitive()) {
if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
return GetRef<Expr>(fn_node);
} else {
return ExprMutator::VisitExpr_(fn_node);
......@@ -932,8 +932,8 @@ class FuseMutator : private ExprMutator {
} visitor;
visitor(body);
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
func = FunctionSetAttr(func, attr::kPrimitive, tvm::Integer(visitor.has_call));
auto func = Function(ginfo.params, body, ret_type, {});
func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
return CallNode::make(func, ginfo.arguments, Attrs());
}
......
......@@ -255,7 +255,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
return Pair(res.forward, grad);
});
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient")
......@@ -384,7 +384,7 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
}
Expr BPEmpty() {
Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleType::Empty(), {});
Expr unitF = Function({}, TupleNode::make({}), TupleType::Empty(), {});
return RefCreateNode::make(unitF);
}
......@@ -413,7 +413,7 @@ struct ReverseAD : ExprMutator {
auto x_var = ll->Push(x);
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
Expr nbp = Function(
{},
LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
......@@ -457,7 +457,7 @@ struct ReverseAD : ExprMutator {
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
Expr nbp = Function(
{},
LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
......@@ -583,7 +583,7 @@ Expr Gradient(const Expr& re, const IRModule& mod) {
};
return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
});
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
TVM_REGISTER_GLOBAL("relay._transform.gradient")
......
......@@ -83,7 +83,7 @@ class Inliner : ExprMutator {
}
Function Inline(const Function& func) {
return FunctionNode::make(func->params,
return Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
......@@ -101,7 +101,7 @@ class Inliner : ExprMutator {
if (!func->body.defined()) return false;
// The function must be annotated with the inline attribute.
if (!func->IsMarkedInline()) return false;
if (!func->HasNonzeroAttr(attr::kInline)) return false;
// The function is not abled to be inlined if any callee under the CallGraph
// of this function cannot be inlined.
......@@ -124,7 +124,7 @@ class Inliner : ExprMutator {
const auto* fn = base_func.as<FunctionNode>();
CHECK(fn) << "Expected to work on a Relay function.";
auto func = FunctionNode::make(fn->params,
auto func = Function(fn->params,
fn->body,
fn->ret_type,
fn->type_params,
......@@ -198,7 +198,7 @@ IRModule Inline(const IRModule& module) {
auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar());
if (const auto* fn = base_func.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
if (func->IsMarkedInline()) {
if (func->HasNonzeroAttr(attr::kInline)) {
CHECK_EQ(cgn->GetRefCount(), 0U)
<< cgn->GetNameHint() << " is marked as inline but not inlined.";
cgn->CleanCallGraphEntries();
......
......@@ -140,9 +140,10 @@ class MergeCompositeWrapper : public ExprMutator {
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
const auto name_node = FunctionGetAttr(func, attr::kComposite).as<tir::StringImmNode>();
const auto name_node =
func->GetAttr<tir::StringImm>(attr::kComposite);
// don't step into existing composite functions
if (name_node && name_node->value != "") {
if (name_node.defined() && name_node->value != "") {
tvm::Array<tvm::relay::Expr> new_args;
for (const auto& arg : call->args) {
auto new_e = this->Mutate(arg);
......@@ -166,8 +167,8 @@ class MergeCompositeWrapper : public ExprMutator {
if (extract.defined()) {
auto free_vars = FreeVars(extract);
// make the composite function
auto f = FunctionNode::make(free_vars, extract, call->checked_type_, {}, Attrs());
f = FunctionSetAttr(f, attr::kComposite, tir::StringImmNode::make(pattern_name_));
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
f = WithAttr(std::move(f), attr::kComposite, tir::StringImmNode::make(pattern_name_));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
......
......@@ -820,7 +820,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
Func VisitFuncStatic(const Function& func, const Expr& var) {
CHECK(IsAtomic(var));
if (func->IsPrimitive()) {
if (func->HasNonzeroAttr(attr::kPrimitive)) {
return ConstEvaluateFunc(func);
}
std::vector<std::pair<Var, PStatic> > free_vars;
......@@ -881,7 +881,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) {
return store_.Extend<Expr>([&]() {
store_.Invalidate();
return FunctionNode::make(func->params,
return Function(func->params,
LetList::With([&](LetList* ll) {
std::vector<PStatic> pv;
for (const auto& v : func->params) {
......
......@@ -211,15 +211,18 @@ class Partitioner : public ExprMutator {
}
auto subgraph_func =
FunctionNode::make(params, input, call->checked_type_, {}, Attrs());
Function(params, input, call->checked_type_, {});
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func =
FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
WithAttr(std::move(subgraph_func), attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kPrimitive, tvm::Integer(1));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
subgraph_func =
WithAttr(std::move(subgraph_func), attr::kInline, tvm::Integer(1));
CHECK(!module_->ContainGlobalVar(name))
<< "Global function " << name << " already exists";
// Create a global function and add it to the IRModule for the subgraph.
......@@ -277,7 +280,7 @@ class Partitioner : public ExprMutator {
params.push_back(new_param);
}
auto body = VisitExpr(op->body);
return FunctionNode::make(params, body, op->ret_type, op->type_params, op->attrs);
return Function(params, body, op->ret_type, op->type_params, op->attrs);
}
}
......@@ -351,7 +354,7 @@ class Partitioner : public ExprMutator {
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = FunctionNode::make(func->params,
func = Function(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
......
......@@ -91,7 +91,7 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map);
*/
inline Expr TransformF(const std::function<Expr(const Expr&)>& func, const Expr& e) {
if (const FunctionNode* f = e.as<FunctionNode>()) {
return FunctionNode::make(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
return Function(f->params, func(f->body), f->ret_type, f->type_params, f->attrs);
} else {
return func(e);
}
......
......@@ -208,10 +208,10 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
Expr e = GetRef<Expr>(f);
Expr ret;
if (f->IsPrimitive()) {
if (f->HasNonzeroAttr(attr::kPrimitive)) {
ret = e;
} else {
ret = FunctionNode::make(f->params,
ret = Function(f->params,
GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
f->ret_type,
f->type_params,
......
......@@ -142,7 +142,7 @@ Function ToCPS(const Function& f,
}
Expr VisitExpr_(const FunctionNode* op, const MCont& k) final {
CHECK(!op->IsPrimitive()) << "primitive func not supported yet.";
CHECK(!op->HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet.";
return k(ToCPS(GetRef<Function>(op), m, cm, vm, answer));
}
......@@ -182,7 +182,7 @@ Function ToCPS(const Function& f,
Expr reify(const MCont& k) {
Var arg = VarNode::make("arg", Type());
return FunctionNode::make({arg}, k(arg), Type(), {}, {});
return Function({arg}, k(arg), Type(), {}, {});
}
Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) {
......@@ -293,7 +293,7 @@ Function ToCPS(const Function& f,
new_params.push_back(remap(v));
}
new_params.push_back(k);
return FunctionNode::make(new_params,
return Function(new_params,
mut.VisitExpr(f->body,
[&](const Expr& e) { return CallNode::make(k, {e}); }),
answer,
......@@ -328,7 +328,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
Function ret = ToCPS(f, m, cm, &var, answer);
auto new_type_params = ret->type_params;
new_type_params.push_back(answer);
return FunctionNode::make(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs);
return Function(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs);
}
Function ToCPS(const Function& f, const IRModule& m) {
......@@ -355,7 +355,7 @@ Function UnCPS(const Function& f) {
// TODO(@M.K.): make alphaequal work on free term
// CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
auto x = VarNode::make("x", new_ret_type);
auto cont = FunctionNode::make({x}, x, new_ret_type, {}, {});
auto cont = Function({x}, x, new_ret_type, {}, {});
tvm::Array<Expr> args;
for (const auto& p : new_params) {
args.push_back(p);
......@@ -366,7 +366,7 @@ Function UnCPS(const Function& f) {
type_args.push_back(tp);
}
type_args.push_back(new_ret_type);
return FunctionNode::make(new_params,
return Function(new_params,
CallNode::make(f, args, {}, type_args),
new_ret_type,
new_type_params,
......
......@@ -82,7 +82,7 @@ TEST(Relay, BuildModule) {
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
......
......@@ -28,11 +28,11 @@ TEST(Relay, SelfReference) {
using namespace tvm;
auto tensor_type = relay::TensorType({}, DataType::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>());
auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto mod = IRModule::FromExpr(fx);
mod = relay::transform::InferType()(mod);
auto type_fx = mod->Lookup("main");
......
......@@ -53,7 +53,7 @@ TEST(Relay, Sequential) {
// Let expression and varaible a should be dead-code eliminated.
auto z3 = relay::LetNode::make(a, c, z2);
relay::Function func =
relay::FunctionNode::make(relay::FreeVars(z3), z3, relay::Type(), {});
relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});
// Get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register");
......@@ -96,7 +96,7 @@ TEST(Relay, Sequential) {
auto zz = relay::CallNode::make(add_op, {y1, c1});
zz = relay::CallNode::make(add_op, {zz, zz});
relay::Function expected_func =
relay::FunctionNode::make(relay::FreeVars(zz), zz, relay::Type(), {});
relay::Function(relay::FreeVars(zz), zz, relay::Type(), {});
// Infer type for the expected function.
auto mod1 = IRModule::FromExpr(expected_func);
......
......@@ -58,7 +58,7 @@ TEST(MicroStandaloneRuntime, BuildModule) {
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
......
......@@ -134,7 +134,7 @@ def test_recursive_func():
func = relay.Function([i],
sb.get(),
ret_type=relay.TensorType([], 'int32'))
func = func.set_attribute("Compiler", tvm.tir.StringImm("a"))
func = func.with_attr("Compiler", tvm.tir.StringImm("a"))
mod[sum_up] = func
iarg = relay.var('i', shape=[], dtype='int32')
mod["main"] = relay.Function([iarg], sum_up(iarg))
......
......@@ -78,9 +78,9 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
def set_external_func_attr(func, compiler, ext_symbol):
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm(compiler))
func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm(compiler))
func = func.with_attr("ExternalSymbol", tvm.tir.StringImm(ext_symbol))
return func
......
......@@ -307,7 +307,7 @@ def get_synthetic_lib():
gcc_input3 = relay.var('gcc_input3', shape=(10, 10))
subgraph0 = relay.Function([gcc_input0, gcc_input1, gcc_input2,
gcc_input3], relay.copy(gcc_input0))
subgraph0 = subgraph0.set_attribute(
subgraph0 = subgraph0.with_attr(
"Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph0
......@@ -320,7 +320,7 @@ def get_synthetic_lib():
gcc_input7 = relay.var('gcc_input7', shape=(10, 10))
subgraph1 = relay.Function([gcc_input4, gcc_input5, gcc_input6,
gcc_input7], relay.copy(gcc_input4))
subgraph1 = subgraph1.set_attribute(
subgraph1 = subgraph1.with_attr(
"Primitive", tvm.tir.IntImm("int32", 1))
# Call subgraph1
......
......@@ -169,15 +169,16 @@ def test_function():
body = relay.Tuple(tvm.runtime.convert([]))
type_params = tvm.runtime.convert([])
fn = relay.Function(params, body, ret_type, type_params)
fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value"))
fn = fn.with_attr("test_attribute", tvm.tir.StringImm("value"))
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.get_attribute("test_attribute") == "value"
assert fn.attrs["test_attribute"] == "value"
str(fn)
check_json_roundtrip(fn)
@pytest.mark.skip(reason="AttrsEqualHandler doesn't handle Map so far.")
def test_function_attrs():
param_names = ['a', 'b', 'c', 'd']
......@@ -190,8 +191,10 @@ def test_function_attrs():
for param in params[:1]:
cty = param.type_annotation
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
model_params[param] = tvm.nd.array(tensor)
fn = fn.set_params(model_params)
model_params[param] = relay.Constant(tvm.nd.array(tensor))
fn = fn.with_attr("__params__", model_params)
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
......@@ -200,7 +203,7 @@ def test_function_attrs():
check_json_roundtrip(fn)
json_str = tvm.ir.save_json(fn)
fn_after = tvm.ir.load_json(json_str)
model_params_after = fn_after.get_params()
model_params_after = fn_after.attrs["__params__"]
after_keys = [item[0] for item in model_params_after.items()]
for key1, key2 in zip(model_params, after_keys):
assert key1.name_hint == key2.name_hint
......@@ -296,4 +299,3 @@ if __name__ == "__main__":
test_tuple_get_item()
test_op()
test_conv2d_attrs()
......@@ -353,7 +353,7 @@ def test_function_attr():
p00 = relay.subtract(z00, w01)
q00 = relay.multiply(p00, w02)
func0 = relay.Function([x0, w00, w01, w02], q00)
func0 = func0.set_attribute("FuncName", tvm.tir.StringImm("a"))
func0 = func0.with_attr("FuncName", tvm.tir.StringImm("a"))
x1 = relay.var('x1', shape=(10, 10))
w10 = relay.var('w10', shape=(10, 10))
......@@ -363,7 +363,7 @@ def test_function_attr():
p10 = relay.subtract(z10, w11)
q10 = relay.multiply(p10, w12)
func1 = relay.Function([x1, w10, w11, w12], q10)
func1 = func1.set_attribute("FuncName", tvm.tir.StringImm("b"))
func1 = func1.with_attr("FuncName", tvm.tir.StringImm("b"))
assert not alpha_equal(func0, func1)
......@@ -694,7 +694,7 @@ def test_fn_attribute():
d = relay.var('d', shape=(10, 10))
add_1 = relay.add(c, d)
add_1_fn = relay.Function([c, d], add_1)
add_1_fn = add_1_fn.set_attribute("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = add_1_fn.with_attr("TestAttribute", tvm.tir.StringImm("test"))
add_1_fn = run_opt_pass(add_1_fn, relay.transform.InferType())
assert not relay.analysis.alpha_equal(add_1_fn, add_fn)
......
......@@ -164,7 +164,7 @@ def test_simple_merge():
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
# merged function
r = relay.Call(add_relu, [a, b])
......@@ -229,7 +229,7 @@ def test_branch_merge():
sub_node = relay.subtract(in_1, in_2)
mul_node = relay.multiply(add_node, sub_node)
add_sub_mul = relay.Function([in_1, in_2], mul_node)
add_sub_mul = add_sub_mul.set_attribute("Composite",
add_sub_mul = add_sub_mul.with_attr("Composite",
tir.StringImm("add_sub_mul"))
# add_sub_mul1 function
......@@ -239,7 +239,7 @@ def test_branch_merge():
sub_node_1 = relay.subtract(in_3, in_4)
mul_node_1 = relay.multiply(add_node_1, sub_node_1)
add_sub_mul_1 = relay.Function([in_3, in_4], mul_node_1)
add_sub_mul_1 = add_sub_mul_1.set_attribute("Composite",
add_sub_mul_1 = add_sub_mul_1.with_attr("Composite",
tir.StringImm("add_sub_mul"))
# merged function
......@@ -299,7 +299,7 @@ def test_reuse_call_merge():
add_node_1 = relay.add(in_1, add_node)
add_node_2 = relay.add(add_node_1, add_node)
add_add_add = relay.Function([in_1, in_2], add_node_2)
add_add_add = add_add_add.set_attribute("Composite",
add_add_add = add_add_add.with_attr("Composite",
tir.StringImm("add_add_add"))
# merged function
......@@ -383,7 +383,7 @@ def test_multiple_patterns():
bias_node = relay.nn.bias_add(conv_node, in_3)
r = relay.nn.relu(bias_node)
conv_bias_add_relu = relay.Function([in_1, in_2, in_3], r)
conv_bias_add_relu = conv_bias_add_relu.set_attribute("Composite",
conv_bias_add_relu = conv_bias_add_relu.with_attr("Composite",
tir.StringImm("conv2d_bias_relu"))
# add_relu function
......@@ -392,7 +392,7 @@ def test_multiple_patterns():
add_node = relay.add(in_4, in_5)
r = relay.nn.relu(add_node)
add_relu = relay.Function([in_4, in_5], r)
add_relu = add_relu.set_attribute("Composite", tir.StringImm("add_relu"))
add_relu = add_relu.with_attr("Composite", tir.StringImm("add_relu"))
# merged function
conv_bias_add_relu_1 = relay.Call(conv_bias_add_relu, [data, kernel, bias])
......@@ -461,7 +461,7 @@ def test_merge_order():
out = relay.abs(out)
out = relay.nn.relu(out)
merged_func = relay.Function([x, y], out)
merged_func = merged_func.set_attribute('Composite',
merged_func = merged_func.with_attr('Composite',
tir.StringImm(composite_name))
ret = relay.Call(merged_func, [input_1, input_2])
return relay.Function([input_1, input_2], ret)
......@@ -527,13 +527,13 @@ def test_parallel_merge():
y = relay.var('y')
branch_1 = relay.multiply(relay.add(x, y), relay.subtract(x, y))
func_1 = relay.Function([x, y], branch_1)
func_1 = func_1.set_attribute('Composite', tir.StringImm("add_sub_mul"))
func_1 = func_1.with_attr('Composite', tir.StringImm("add_sub_mul"))
call_1 = relay.Call(func_1, [input_1, input_2])
x1 = relay.var('x1')
y1 = relay.var('y1')
branch_2 = relay.multiply(relay.add(x1, y1), relay.subtract(x1, y1))
func_2 = relay.Function([x1, y1], branch_2)
func_2 = func_2.set_attribute('Composite', tir.StringImm("add_sub_mul"))
func_2 = func_2.with_attr('Composite', tir.StringImm("add_sub_mul"))
call_2 = relay.Call(func_2, [input_1, input_2])
out = relay.multiply(call_1, call_2)
return relay.Function([input_1, input_2], out)
......@@ -612,14 +612,14 @@ def test_multiple_input_subgraphs():
add_relu_1 = relay.add(x, y)
add_relu_1 = relay.nn.relu(add_relu_1)
add_relu_1 = relay.Function([x, y], add_relu_1)
add_relu_1 = add_relu_1.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_1 = add_relu_1.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_call_1 = relay.Call(add_relu_1, [inputs[0], inputs[1]])
x1 = relay.var('x1')
y1 = relay.var('y1')
add_relu_2 = relay.add(x1, y1)
add_relu_2 = relay.nn.relu(add_relu_2)
add_relu_2 = relay.Function([x1, y1], add_relu_2)
add_relu_2 = add_relu_2.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu_2 = add_relu_2.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_call_2 = relay.Call(add_relu_2, [inputs[2], inputs[3]])
x2 = relay.var('x2')
y2 = relay.var('y2')
......@@ -627,7 +627,7 @@ def test_multiple_input_subgraphs():
sub = relay.subtract(x2, y2)
add_sub_mul = relay.multiply(add, sub)
add_sub_mul = relay.Function([x2, y2], add_sub_mul)
add_sub_mul = add_sub_mul.set_attribute('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul = add_sub_mul.with_attr('Composite', tir.StringImm('add_sub_mul'))
add_sub_mul_call = relay.Call(add_sub_mul, [add_relu_call_1, add_relu_call_2])
return relay.Function(inputs, add_sub_mul_call)
......@@ -640,7 +640,7 @@ def test_multiple_input_subgraphs():
add_relu = relay.add(x, y)
add_relu = relay.nn.relu(add_relu)
add_relu = relay.Function([x, y], add_relu)
add_relu = add_relu.set_attribute('Composite', tir.StringImm('add_relu'))
add_relu = add_relu.with_attr('Composite', tir.StringImm('add_relu'))
add_relu_call = relay.Call(add_relu, [inputs[i*2], inputs[i*2+1]])
add_relu_calls.append(add_relu_call)
......
......@@ -303,11 +303,11 @@ def test_extern_ccompiler_default_ops():
add = x0 + y0
# Function that uses C compiler
func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler",
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler",
tvm.tir.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol",
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
......@@ -318,7 +318,7 @@ def test_extern_ccompiler_default_ops():
exp = relay.exp(p0)
concat = relay.concatenate([log, exp], axis=0)
fused_func = relay.Function([p0], concat)
fused_func = fused_func.set_attribute("Primitive",
fused_func = fused_func.with_attr("Primitive",
tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
......@@ -390,10 +390,10 @@ def test_extern_dnnl():
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
func = relay.Function([data0, input0, input1], out)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl"))
func = func.set_attribute("ExternalSymbol",
func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func = func.with_attr("Compiler", tvm.tir.StringImm("dnnl"))
func = func.with_attr("ExternalSymbol",
tvm.tir.StringImm("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
......@@ -516,11 +516,11 @@ def test_function_lifting():
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler",
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol",
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
gv0 = relay.GlobalVar("test_compiler_0")
mod[gv0] = func0
......@@ -535,11 +535,11 @@ def test_function_lifting():
channels=16,
padding=(1, 1))
func1 = relay.Function([data1, weight1], conv)
func1 = func1.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.set_attribute("Compiler",
func1 = func1.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func1 = func1.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func1 = func1.set_attribute("ExternalSymbol",
func1 = func1.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_1"))
gv1 = relay.GlobalVar("test_compiler_1")
mod[gv1] = func1
......@@ -609,11 +609,11 @@ def test_function_lifting_inline():
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple())
func0 = func0.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.set_attribute("Compiler",
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler",
tvm.tir.StringImm("test_compiler"))
func0 = func0.set_attribute("ExternalSymbol",
func0 = func0.with_attr("ExternalSymbol",
tvm.tir.StringImm("test_compiler_0"))
# main function
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_make_attrs():
try:
x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("unknown_key") != -1
try:
x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("upper bound") != -1
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
def test_dict_attrs():
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
assert dattr.name.value == "xyz"
assert isinstance(dattr, tvm.ir.DictAttrs)
assert "name" in dattr
assert dattr["x"].value == 1
assert len(dattr) == 4
assert len([x for x in dattr.keys()]) == 4
assert len(dattr.items()) == 4
if __name__ == "__main__":
test_make_attrs()
test_dict_attrs()
......@@ -54,33 +54,6 @@ def test_make_node():
assert AA.value_index == A.value_index
def test_make_attrs():
try:
x = tvm.ir.make_node("attrs.TestAttrs", unknown_key=1, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("unknown_key") != -1
try:
x = tvm.ir.make_node("attrs.TestAttrs", axis=100, name="xx")
assert False
except tvm.error.TVMError as e:
assert str(e).find("upper bound") != -1
x = tvm.ir.make_node("attrs.TestAttrs", name="xx", padding=(3,4))
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.ir.load_json(tvm.ir.save_json(dattr))
assert dattr.name.value == "xyz"
def test_make_sum():
A = te.placeholder((2, 10), name='A')
k = te.reduce_axis((0,10), "k")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment