Unverified Commit a6724b6e by Tianqi Chen Committed by GitHub

[NODE] Enable EnvFunc to serialize global function as node (#1721)

parent 43126602
/*!
* Copyright (c) 2017 by Contributors
* \file tvm/api_registry.h
* \brief This files include necessary headers to
* be used to register an global API function.
* \brief This file contains utilities related to
* the TVM's global function registry.
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#include <string>
#include "base.h"
#include "packed_func_ext.h"
#include "runtime/registry.h"
namespace tvm {
/*!
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
......@@ -24,4 +26,113 @@
*/
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)
/*!
* \brief Node container of EnvFunc
* \sa EnvFunc
*/
class EnvFuncNode : public Node {
public:
/*! \brief Unique name of the global function */
std::string name;
/*! \brief The internal packed function */
PackedFunc func;
/*! \brief constructor */
EnvFuncNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
}
static constexpr const char* _type_key = "EnvFunc";
TVM_DECLARE_NODE_TYPE_INFO(EnvFuncNode, Node);
};
/*!
* \brief A serializable function backed by TVM's global environment.
*
* This is a wrapper to enable serializable global PackedFunc.
* An EnvFunc is saved by its name in the global registry
* under the assumption that the same function is registered during load.
*/
class EnvFunc : public NodeRef {
public:
EnvFunc() {}
explicit EnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get());
}
/*!
* \brief Invoke the function.
* \param args The arguments
* \returns The return value.
*/
template<typename... Args>
runtime::TVMRetValue operator()(Args&&... args) const {
const EnvFuncNode* n = operator->();
CHECK(n != nullptr);
return n->func(std::forward<Args>(args)...);
}
/*!
* \brief Get a global function based on the name.
* \param name The name of the global function.
* \return The created global function.
* \note The function can be unique
*/
TVM_DLL static EnvFunc Get(const std::string& name);
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
};
/*!
* \brief Please refer to \ref TypedEnvFuncAnchor "TypedEnvFunc<R(Args..)>"
*/
template<typename FType>
class TypedEnvFunc;
/*!
* \anchor TypedEnvFuncAnchor
* \brief A typed version of EnvFunc.
* It is backed by a GlobalFuncNode internally.
*
* \tparam R The return value of the function.
* \tparam Args The argument signature of the function.
* \sa EnvFunc
*/
template<typename R, typename... Args>
class TypedEnvFunc<R(Args...)> : public NodeRef {
public:
/*! \brief short hand for this function type */
using TSelf = TypedEnvFunc<R(Args...)>;
TypedEnvFunc() {}
explicit TypedEnvFunc(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief Assign global function to a TypedEnvFunc
* \param other Another global function.
* \return reference to self.
*/
TSelf& operator=(const EnvFunc& other) {
this->node_ = other.node_;
return *this;
}
/*! \return The internal global function pointer */
const EnvFuncNode* operator->() const {
return static_cast<EnvFuncNode*>(node_.get());
}
/*!
* \brief Invoke the function.
* \param args The arguments
* \returns The return value.
*/
R operator()(Args... args) const {
const EnvFuncNode* n = operator->();
CHECK(n != nullptr);
return runtime::detail::typed_packed_call_dispatcher<R>
::run(n->func, std::forward<Args>(args)...);
}
/*! \brief specify container node */
using ContainerType = EnvFuncNode;
};
} // namespace tvm
#endif // TVM_API_REGISTRY_H_
......@@ -257,6 +257,14 @@ class TypedPackedFunc<R(Args...)> {
const PackedFunc& packed() const {
return packed_;
}
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const {
return packed_ == nullptr;
}
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const {
return packed_ != nullptr;
}
private:
friend class TVMRetValue;
......
......@@ -45,6 +45,28 @@ def const(value, dtype=None):
return _api_internal._const(value, dtype)
def get_env_func(name):
"""Get an EnvFunc by a global name.
Parameters
----------
name: str
The name of the global function.
Returns
-------
env_func : EnvFunc
The result env function.
Note
----
EnvFunc is a Node wrapper around
global function that can be serialized via its name.
This can be used to serialize function field in the language.
"""
return _api_internal._EnvFuncGet(name)
def convert(value):
"""Convert value to TVM node or function.
......
......@@ -28,6 +28,20 @@ class Array(NodeBase):
@register_node
class EnvFunc(NodeBase):
"""Environment function.
This is a global function object that can be serialized by its name.
"""
def __call__(self, *args):
return _api_internal._EnvFuncCall(self, *args)
@property
def func(self):
return _api_internal._EnvFuncGetPackedFunc(self)
@register_node
class Map(NodeBase):
"""Map container of TVM.
......
......@@ -14,6 +14,7 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
int axis;
std::string name;
Array<Expr> padding;
TypedEnvFunc<int(int)> func;
TVM_DECLARE_ATTRS(TestAttrs, "attrs.TestAttrs") {
TVM_ATTR_FIELD(axis)
......@@ -26,6 +27,9 @@ struct TestAttrs : public AttrsNode<TestAttrs> {
TVM_ATTR_FIELD(padding)
.describe("padding of input")
.set_default(Array<Expr>({0, 0}));
TVM_ATTR_FIELD(func)
.describe("some random env function")
.set_default(TypedEnvFunc<int(int)>(nullptr));
}
};
......
/*!
* Copyright (c) 2018 by Contributors
* \file api_registry.cc
*/
#include <tvm/api_registry.h>
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<EnvFuncNode>([](const EnvFuncNode *op, IRPrinter *p) {
p->stream << "EnvFunc(" << op->name << ")";
});
std::shared_ptr<EnvFuncNode> CreateEnvNode(const std::string& name) {
auto* f = runtime::Registry::Get(name);
CHECK(f != nullptr) << "Cannot find global function \'" << name << '\'';
std::shared_ptr<EnvFuncNode> n = std::make_shared<EnvFuncNode>();
n->func = *f;
n->name = name;
return n;
}
EnvFunc EnvFunc::Get(const std::string& name) {
return EnvFunc(CreateEnvNode(name));
}
TVM_REGISTER_API("_EnvFuncGet")
.set_body_typed<EnvFunc(const std::string& name)>(EnvFunc::Get);
TVM_REGISTER_API("_EnvFuncCall")
.set_body([](TVMArgs args, TVMRetValue* rv) {
EnvFunc env = args[0];
CHECK_GE(args.size(), 1);
env->func.CallPacked(TVMArgs(args.values + 1,
args.type_codes + 1,
args.size() - 1), rv);
});
TVM_REGISTER_API("_EnvFuncGetPackedFunc")
.set_body_typed<PackedFunc(const EnvFunc& n)>([](const EnvFunc&n) {
return n->func;
});
TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode)
.set_global_key([](const Node* n) {
return static_cast<const EnvFuncNode*>(n)->name;
});
} // namespace tvm
......@@ -56,11 +56,14 @@ def test_make_attrs():
assert x.padding[1].value == 4
assert x.axis == 10
dattr = tvm.make.node("DictAttrs", x=1, y=10, name="xyz", padding=(0,0))
assert dattr.x.value == 1
datrr = tvm.load_json(tvm.save_json(dattr))
assert dattr.name.value == "xyz"
def test_make_sum():
A = tvm.placeholder((2, 10), name='A')
k = tvm.reduce_axis((0,10), "k")
......@@ -70,7 +73,33 @@ def test_make_sum():
assert B.op.body[0].combiner is not None
assert BB.op.body[0].combiner is not None
def test_env_func():
@tvm.register_func("test.env_func")
def test(x):
return x + 1
f = tvm.get_global_func("test.env_func")
x = tvm.get_env_func("test.env_func")
assert x.name == "test.env_func"
json_str = tvm.save_json([x])
y = tvm.load_json(json_str)[0]
assert y.name == x.name
assert y(1) == 2
assert y.func(1) == 2
x = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4), func=y)
assert x.name == "xx"
assert x.padding[0].value == 3
assert x.padding[1].value == 4
assert x.axis == 10
x = tvm.load_json(tvm.save_json(x))
assert isinstance(x.func, tvm.container.EnvFunc)
assert x.func(10) == 11
if __name__ == "__main__":
test_env_func()
test_make_attrs()
test_make_node()
test_make_smap()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment