Commit 458d4f19 by ziheng Committed by Tianqi Chen

[EXECUTOR] Add GraphHandle (#285)

* [GRAPH] Add GraphHandle

* Move to apps/graph_executor
parent 7bc5b5e5
Subproject commit 36ecc1eec0898411ae70e98c315b03247d5fb4a0 Subproject commit 326e2fa18734f0592d257da6b8cfaae90a499c5c
/*!
* Copyright (c) 2017 by Contributors
* \file graph_handle.cc
*/
#include <tvm/packed_func_ext.h>
#include "./graph_handle.h"
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphHandleNode>([](const GraphHandleNode *op, IRPrinter *p) {
p->stream << "graph-handle("
<< "handle=0x" << std::hex
<< reinterpret_cast<uint64_t>(op->graph_handle) << ")";
});
TVM_REGISTER_NODE_TYPE(GraphHandleNode);
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file graph.h
* \brief Data structure about computational graph.
*/
#ifndef TVM_GRAPH_HANDLE_H_
#define TVM_GRAPH_HANDLE_H_
#include <string>
#include <tvm/base.h>
namespace tvm {
/*!
* \brief Computational graph handle.
* Use GraphHandle as its container type
*/
struct GraphHandleNode : public Node {
void *graph_handle;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("graph_handle", &graph_handle);
}
static constexpr const char* _type_key = "GraphHandle";
TVM_DECLARE_NODE_TYPE_INFO(GraphHandleNode, Node);
};
/*! \brief Defines graph handle */
TVM_DEFINE_NODE_REF(GraphHandle, GraphHandleNode);
} // namespace tvm
#endif // TVM_GRAPH_HANDLE_H_
...@@ -523,12 +523,12 @@ nnvm::Graph PruneGraph(nnvm::Graph src) { ...@@ -523,12 +523,12 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
if (can_be_pruned) { if (can_be_pruned) {
pruned.emplace(n.get()); pruned.emplace(n.get());
} else { } else {
// scan again to find edge nodes // scan again to find edge nodes, skip variables
for (auto& e : n->inputs) { for (auto& e : n->inputs) {
if (pruned.count(e.node.get())) { if (!e.node->is_variable() && pruned.count(e.node.get())) {
if (!entry_var.count(e)) { if (!entry_var.count(e)) {
nnvm::NodePtr var = nnvm::Node::Create(); nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index); var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
entry_var.emplace(e, var); entry_var.emplace(e, var);
} }
e = nnvm::NodeEntry{entry_var.at(e), 0, 0}; e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
...@@ -542,6 +542,7 @@ nnvm::Graph PruneGraph(nnvm::Graph src) { ...@@ -542,6 +542,7 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
std::vector<std::string> output_names; std::vector<std::string> output_names;
output_names.reserve(entry_var.size()); output_names.reserve(entry_var.size());
for (auto kv : entry_var) { for (auto kv : entry_var) {
if (kv.first.node->is_variable()) continue;
pre_graph.outputs.emplace_back(kv.first); pre_graph.outputs.emplace_back(kv.first);
output_names.emplace_back(kv.second->attrs.name); output_names.emplace_back(kv.second->attrs.name);
} }
......
...@@ -51,6 +51,9 @@ struct APIAttrGetter : public AttrVisitor { ...@@ -51,6 +51,9 @@ struct APIAttrGetter : public AttrVisitor {
void Visit(const char* key, bool* value) final { void Visit(const char* key, bool* value) final {
if (skey == key) *ret = static_cast<int64_t>(value[0]); if (skey == key) *ret = static_cast<int64_t>(value[0]);
} }
void Visit(const char* key, void** value) final {
if (skey == key) *ret = static_cast<void*>(value[0]);
}
void Visit(const char* key, Type* value) final { void Visit(const char* key, Type* value) final {
if (skey == key) *ret = value[0]; if (skey == key) *ret = value[0];
} }
...@@ -83,6 +86,9 @@ struct APIAttrDir : public AttrVisitor { ...@@ -83,6 +86,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, int* value) final { void Visit(const char* key, int* value) final {
names->push_back(key); names->push_back(key);
} }
void Visit(const char* key, void** value) final {
names->push_back(key);
}
void Visit(const char* key, Type* value) final { void Visit(const char* key, Type* value) final {
names->push_back(key); names->push_back(key);
} }
......
...@@ -59,6 +59,7 @@ class NodeIndexer : public AttrVisitor { ...@@ -59,6 +59,7 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, int* value) final {} void Visit(const char* key, int* value) final {}
void Visit(const char* key, bool* value) final {} void Visit(const char* key, bool* value) final {}
void Visit(const char* key, std::string* value) final {} void Visit(const char* key, std::string* value) final {}
void Visit(const char* key, void** value) final {}
void Visit(const char* key, Type* value) final {} void Visit(const char* key, Type* value) final {}
void Visit(const char* key, NodeRef* value) final { void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get()); MakeIndex(value->node_.get());
...@@ -148,6 +149,9 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -148,6 +149,9 @@ class JSONAttrGetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final { void Visit(const char* key, std::string* value) final {
node_->attrs[key] = *value; node_->attrs[key] = *value;
} }
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to serialize a pointer";
}
void Visit(const char* key, Type* value) final { void Visit(const char* key, Type* value) final {
node_->attrs[key] = Type2String(*value); node_->attrs[key] = Type2String(*value);
} }
...@@ -223,6 +227,9 @@ class JSONAttrSetter : public AttrVisitor { ...@@ -223,6 +227,9 @@ class JSONAttrSetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final { void Visit(const char* key, std::string* value) final {
*value = GetValue(key); *value = GetValue(key);
} }
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to deserialize a pointer";
}
void Visit(const char* key, Type* value) final { void Visit(const char* key, Type* value) final {
std::string stype = GetValue(key); std::string stype = GetValue(key);
*value = String2Type(stype); *value = String2Type(stype);
...@@ -371,6 +378,9 @@ class NodeAttrSetter : public AttrVisitor { ...@@ -371,6 +378,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final { void Visit(const char* key, std::string* value) final {
SetValue(key, value); SetValue(key, value);
} }
void Visit(const char* key, void** value) final {
SetValue(key, value);
}
void Visit(const char* key, Type* value) final { void Visit(const char* key, Type* value) final {
SetValue(key, value); SetValue(key, value);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment