Unverified Commit 531efd6f by Tianqi Chen Committed by GitHub

[NODE] Enable global singleton object, allow set_body_typed in function…

[NODE] Enable global singleton object, allow set_body_typed in function registry, default fallback of IRPrinter. (#1652)
parent 771d895d
Subproject commit a0b9563f45719553adf4d39fe3c14db1af0e1f40 Subproject commit 6f64f7866747a2a81bec84aea9bde0479c5b72c1
...@@ -68,26 +68,72 @@ inline NodeType LoadJSON(const std::string& json_str) { ...@@ -68,26 +68,72 @@ inline NodeType LoadJSON(const std::string& json_str) {
return NodeType(LoadJSON_(json_str)); return NodeType(LoadJSON_(json_str));
} }
/*! \brief typedef the factory function of data iterator */
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
/*! /*!
* \brief Registry entry for NodeFactory * \brief Registry entry for NodeFactory.
*
* There are two types of Nodes that can be serialized.
* The normal node requires a registration a creator function that
* constructs an empty Node of the corresponding type.
*
* The global singleton(e.g. global operator) where only global_key need to be serialized,
* in this case, FGlobalKey need to be defined.
*/
struct NodeFactoryReg {
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey
* \return The created function.
*/
using FCreate = std::function<std::shared_ptr<Node>(const std::string& global_key)>;
/*!
* \brief Global key function, only needed by global objects.
* \param node The node pointer.
* \return node The global key to the node.
*/ */
struct NodeFactoryReg using FGlobalKey = std::function<std::string(const Node* node)>;
: public dmlc::FunctionRegEntryBase<NodeFactoryReg, /*! \brief registered name */
NodeFactory> { std::string name;
/*!
* \brief The creator function
*/
FCreate fcreator = nullptr;
/*!
* \brief The global key function.
*/
FGlobalKey fglobal_key = nullptr;
// setter of creator
NodeFactoryReg& set_creator(FCreate f) { // NOLINT(*)
this->fcreator = f;
return *this;
}
// setter of creator
NodeFactoryReg& set_global_key(FGlobalKey f) { // NOLINT(*)
this->fglobal_key = f;
return *this;
}
// global registry singleton
TVM_DLL static ::dmlc::Registry<::tvm::NodeFactoryReg> *Registry();
}; };
/*!
* \brief Register a Node type
* \note This is necessary to enable serialization of the Node.
*/
#define TVM_REGISTER_NODE_TYPE(TypeName) \ #define TVM_REGISTER_NODE_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \ static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \ ::tvm::NodeFactoryReg::Registry()->__REGISTER__(TypeName::_type_key) \
.set_body([]() { return std::make_shared<TypeName>(); }) .set_creator([](const std::string&) { return std::make_shared<TypeName>(); })
#define TVM_STRINGIZE_DETAIL(x) #x
#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x)
#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__))
/*!
* \brief Macro to include current line as string
*/
#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__)
TVM_DLL::dmlc::Registry<::tvm::NodeFactoryReg > * GetTVMNodeFactoryRegistry();
#define TVM_EXTERNAL_REGISTER_NODE_TYPE(TypeName) \
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::tvm::GetTVMNodeFactoryRegistry()->__REGISTER__(TypeName::_type_key) \
.set_body([]() { return std::make_shared<TypeName>(); })
} // namespace tvm } // namespace tvm
#endif // TVM_BASE_H_ #endif // TVM_BASE_H_
...@@ -48,6 +48,24 @@ class Registry { ...@@ -48,6 +48,24 @@ class Registry {
return set_body(PackedFunc(f)); return set_body(PackedFunc(f));
} }
/*! /*!
* \brief set the body of the function to be TypedPackedFunc.
*
* \code
*
* TVM_REGISTER_API("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; });
*
* \endcode
*
* \param f The body of the function.
* \tparam FType the signature of the function.
* \tparam FLambda The type of f.
*/
template<typename FType, typename FLambda>
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}
/*!
* \brief Register a function with given name * \brief Register a function with given name
* \param name The name of the function. * \param name The name of the function.
* \param override Whether allow oveeride existing function. * \param override Whether allow oveeride existing function.
......
...@@ -100,6 +100,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") ...@@ -100,6 +100,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
*rv = ret; *rv = ret;
}); });
TVM_EXTERNAL_REGISTER_NODE_TYPE(NDArrayWrapperNode); TVM_REGISTER_NODE_TYPE(NDArrayWrapperNode);
} // namespace compiler } // namespace compiler
} // namespace nnvm } // namespace nnvm
...@@ -24,21 +24,13 @@ TVM_REGISTER_API("_raw_ptr") ...@@ -24,21 +24,13 @@ TVM_REGISTER_API("_raw_ptr")
}); });
TVM_REGISTER_API("_save_json") TVM_REGISTER_API("_save_json")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<std::string(NodeRef)>(SaveJSON);
*ret = SaveJSON(args[0]);
});
TVM_REGISTER_API("_load_json") TVM_REGISTER_API("_load_json")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
*ret = LoadJSON<NodeRef>(args[0]);
});
TVM_REGISTER_API("_TVMSetStream") TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]); TVMSetStream(args[0], args[1], args[2]);
}); });
TVM_DLL::dmlc::Registry<::tvm::NodeFactoryReg > * GetTVMNodeFactoryRegistry() {
return ::dmlc::Registry<::tvm::NodeFactoryReg>::Get();
}
} // namespace tvm } // namespace tvm
...@@ -20,6 +20,10 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg); ...@@ -20,6 +20,10 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
namespace tvm { namespace tvm {
::dmlc::Registry<NodeFactoryReg>* NodeFactoryReg::Registry() {
return ::dmlc::Registry<NodeFactoryReg>::Get();
}
inline std::string Type2String(const Type& t) { inline std::string Type2String(const Type& t) {
if (t.code() ==Type::Handle) return "handle"; if (t.code() ==Type::Handle) return "handle";
std::ostringstream os; std::ostringstream os;
...@@ -115,6 +119,8 @@ using AttrMap = std::map<std::string, std::string>; ...@@ -115,6 +119,8 @@ using AttrMap = std::map<std::string, std::string>;
struct JSONNode { struct JSONNode {
// The type key of the data // The type key of the data
std::string type_key; std::string type_key;
// The global key for global object
std::string global_key;
// the attributes // the attributes
AttrMap attrs; AttrMap attrs;
// container keys // container keys
...@@ -125,6 +131,9 @@ struct JSONNode { ...@@ -125,6 +131,9 @@ struct JSONNode {
void Save(dmlc::JSONWriter *writer) const { void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject(); writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key); writer->WriteObjectKeyValue("type_key", type_key);
if (global_key.size() != 0) {
writer->WriteObjectKeyValue("global_key", global_key);
}
if (attrs.size() != 0) { if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs); writer->WriteObjectKeyValue("attrs", attrs);
} }
...@@ -140,9 +149,11 @@ struct JSONNode { ...@@ -140,9 +149,11 @@ struct JSONNode {
void Load(dmlc::JSONReader *reader) { void Load(dmlc::JSONReader *reader) {
attrs.clear(); attrs.clear();
data.clear(); data.clear();
global_key.clear();
type_key.clear(); type_key.clear();
dmlc::JSONObjectReadHelper helper; dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key); helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("global_key", &global_key);
helper.DeclareOptionalField("attrs", &attrs); helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys); helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data); helper.DeclareOptionalField("data", &data);
...@@ -195,6 +206,12 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -195,6 +206,12 @@ class JSONAttrGetter : public AttrVisitor {
return; return;
} }
node_->type_key = node->type_key(); node_->type_key = node->type_key();
// sepcially handle global object
auto* f = dmlc::Registry<NodeFactoryReg>::Find(node_->type_key);
if (f->fglobal_key != nullptr) {
node_->global_key = f->fglobal_key(node);
return;
}
node_->attrs.clear(); node_->attrs.clear();
node_->data.clear(); node_->data.clear();
if (node->is_type<ArrayNode>()) { if (node->is_type<ArrayNode>()) {
...@@ -403,7 +420,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) { ...@@ -403,7 +420,7 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
auto* f = dmlc::Registry<NodeFactoryReg>::Find(jnode.type_key); auto* f = dmlc::Registry<NodeFactoryReg>::Find(jnode.type_key);
CHECK(f != nullptr) CHECK(f != nullptr)
<< "Node type \'" << jnode.type_key << "\' is not registered in TVM"; << "Node type \'" << jnode.type_key << "\' is not registered in TVM";
nodes.emplace_back(f->body()); nodes.emplace_back(f->fcreator(jnode.global_key));
} else { } else {
nodes.emplace_back(std::shared_ptr<Node>()); nodes.emplace_back(std::shared_ptr<Node>());
} }
...@@ -415,8 +432,12 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) { ...@@ -415,8 +432,12 @@ std::shared_ptr<Node> LoadJSON_(std::string json_str) {
for (size_t i = 0; i < nodes.size(); ++i) { for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i]; setter.node_ = &jgraph.nodes[i];
// do not need to recover content of global singleton object
// they are registered via the environment
if (setter.node_->global_key.length() == 0) {
setter.Set(nodes[i].get()); setter.Set(nodes[i].get());
} }
}
return nodes.at(jgraph.root); return nodes.at(jgraph.root);
} }
...@@ -493,11 +514,14 @@ void InitNodeByPackedArgs(Node* n, const TVMArgs& args) { ...@@ -493,11 +514,14 @@ void InitNodeByPackedArgs(Node* n, const TVMArgs& args) {
// key1, value1, ..., key_n, value_n // key1, value1, ..., key_n, value_n
void MakeNode(const TVMArgs& args, TVMRetValue* rv) { void MakeNode(const TVMArgs& args, TVMRetValue* rv) {
std::string type_key = args[0]; std::string type_key = args[0];
std::string empty_str;
auto* f = dmlc::Registry<NodeFactoryReg>::Find(type_key); auto* f = dmlc::Registry<NodeFactoryReg>::Find(type_key);
CHECK(f != nullptr) CHECK(f != nullptr)
<< "Node type \'" << type_key << "\' is not registered in TVM"; << "Node type \'" << type_key << "\' is not registered in TVM";
TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1); TVMArgs kwargs(args.values + 1, args.type_codes + 1, args.size() - 1);
std::shared_ptr<Node> n = f->body(); CHECK(f->fglobal_key == nullptr)
<< "Cannot make node type \'" << type_key << "\' with global_key.";
std::shared_ptr<Node> n = f->fcreator(empty_str);
if (n->derived_from<BaseAttrsNode>()) { if (n->derived_from<BaseAttrsNode>()) {
static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs); static_cast<BaseAttrsNode*>(n.get())->InitByPackedArgs(kwargs);
} else { } else {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment