Commit 458d4f19 by ziheng Committed by Tianqi Chen

[EXECUTOR] Add GraphHandle (#285)

* [GRAPH] Add GraphHandle

* Move to apps/graph_executor
parent 7bc5b5e5
Subproject commit 36ecc1eec0898411ae70e98c315b03247d5fb4a0
Subproject commit 326e2fa18734f0592d257da6b8cfaae90a499c5c
/*!
* Copyright (c) 2017 by Contributors
* \file graph_handle.cc
*/
#include <tvm/packed_func_ext.h>
#include "./graph_handle.h"
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphHandleNode>([](const GraphHandleNode *op, IRPrinter *p) {
p->stream << "graph-handle("
<< "handle=0x" << std::hex
<< reinterpret_cast<uint64_t>(op->graph_handle) << ")";
});
TVM_REGISTER_NODE_TYPE(GraphHandleNode);
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file graph.h
* \brief Data structure about computational graph.
*/
#ifndef TVM_GRAPH_HANDLE_H_
#define TVM_GRAPH_HANDLE_H_
#include <string>
#include <tvm/base.h>
namespace tvm {
/*!
* \brief Computational graph handle.
* Use GraphHandle as its container type
*/
struct GraphHandleNode : public Node {
void *graph_handle;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("graph_handle", &graph_handle);
}
static constexpr const char* _type_key = "GraphHandle";
TVM_DECLARE_NODE_TYPE_INFO(GraphHandleNode, Node);
};
/*! \brief Defines graph handle */
TVM_DEFINE_NODE_REF(GraphHandle, GraphHandleNode);
} // namespace tvm
#endif // TVM_GRAPH_HANDLE_H_
......@@ -523,12 +523,12 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
if (can_be_pruned) {
pruned.emplace(n.get());
} else {
// scan again to find edge nodes
// scan again to find edge nodes, skip variables
for (auto& e : n->inputs) {
if (pruned.count(e.node.get())) {
if (!e.node->is_variable() && pruned.count(e.node.get())) {
if (!entry_var.count(e)) {
nnvm::NodePtr var = nnvm::Node::Create();
var->attrs.name = e.node->attrs.name + "_" + std::to_string(e.index);
var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
entry_var.emplace(e, var);
}
e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
......@@ -542,6 +542,7 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
std::vector<std::string> output_names;
output_names.reserve(entry_var.size());
for (auto kv : entry_var) {
if (kv.first.node->is_variable()) continue;
pre_graph.outputs.emplace_back(kv.first);
output_names.emplace_back(kv.second->attrs.name);
}
......
......@@ -51,6 +51,9 @@ struct APIAttrGetter : public AttrVisitor {
void Visit(const char* key, bool* value) final {
if (skey == key) *ret = static_cast<int64_t>(value[0]);
}
void Visit(const char* key, void** value) final {
if (skey == key) *ret = static_cast<void*>(value[0]);
}
void Visit(const char* key, Type* value) final {
if (skey == key) *ret = value[0];
}
......@@ -83,6 +86,9 @@ struct APIAttrDir : public AttrVisitor {
void Visit(const char* key, int* value) final {
names->push_back(key);
}
void Visit(const char* key, void** value) final {
names->push_back(key);
}
void Visit(const char* key, Type* value) final {
names->push_back(key);
}
......
......@@ -59,6 +59,7 @@ class NodeIndexer : public AttrVisitor {
void Visit(const char* key, int* value) final {}
void Visit(const char* key, bool* value) final {}
void Visit(const char* key, std::string* value) final {}
void Visit(const char* key, void** value) final {}
void Visit(const char* key, Type* value) final {}
void Visit(const char* key, NodeRef* value) final {
MakeIndex(value->node_.get());
......@@ -148,6 +149,9 @@ class JSONAttrGetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final {
node_->attrs[key] = *value;
}
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to serialize a pointer";
}
void Visit(const char* key, Type* value) final {
node_->attrs[key] = Type2String(*value);
}
......@@ -223,6 +227,9 @@ class JSONAttrSetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final {
*value = GetValue(key);
}
void Visit(const char* key, void** value) final {
LOG(FATAL) << "not allowed to deserialize a pointer";
}
void Visit(const char* key, Type* value) final {
std::string stype = GetValue(key);
*value = String2Type(stype);
......@@ -371,6 +378,9 @@ class NodeAttrSetter : public AttrVisitor {
void Visit(const char* key, std::string* value) final {
SetValue(key, value);
}
void Visit(const char* key, void** value) final {
SetValue(key, value);
}
void Visit(const char* key, Type* value) final {
SetValue(key, value);
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment