Commit 343c19a5 by Tianqi Chen

[COMPILER] GraphHash based cache system, allow dump and query duplicated functions. (#30)

parent 300ae30a
......@@ -63,11 +63,11 @@ class Graph {
* \return The indexed graph.
* \sa IndexedGraph
*/
const IndexedGraph& indexed_graph();
const IndexedGraph& indexed_graph() const;
private:
// internal structure of indexed graph
std::shared_ptr<const IndexedGraph> indexed_graph_;
mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
};
/*!
......
......@@ -41,6 +41,17 @@ inline std::string SaveJSON(Graph graph) {
return ret.GetAttr<std::string>("json");
}
/*!
* \brief Print graph ir
* \param graph The graph to be printed
* \return The graph ir string.
*/
inline std::string PrintGraphIR(Graph graph) {
Graph ret = ApplyPass(std::move(graph), "PrintGraphIR");
return ret.GetAttr<std::string>("graphir");
}
/*!
* \brief Add control flow dependencies between nodes.
*
......
......@@ -5,6 +5,7 @@ import tvm
from . import build_module
from . build_module import build, optimize, build_config
from . compile_engine import engine, graph_key
from .. import symbol as _symbol
from .. import graph as _graph
......@@ -14,5 +15,6 @@ from .registry import register_compute, register_schedule, register_pattern
from .. import top as _top
tvm.register_extension(_symbol.Symbol, _symbol.Symbol)
tvm.register_extension(_graph.Graph, _graph.Graph)
......@@ -184,7 +184,7 @@ def build(graph, target, shape, dtype="float32", params=None):
graph._set_json_attr("target", target, "str")
graph._set_json_attr("opt_level", cfg.opt_level, "int")
graph = graph.apply("InferShape").apply("InferType")
graph = graph.apply("GraphFusePartition").apply("GraphFuse")
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module")
return graph, libmod, params
......
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine"""
import tvm
_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache")
_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem")
_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem")
_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph")
_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey")
@tvm.register_node
class GraphKey(tvm.node.NodeBase):
"""Key of a graph compilation context"""
@property
def graph(self):
return _graph_key_get_graph(self)
@tvm.register_node
class GraphCacheEntry(tvm.node.NodeBase):
"""CacheEntry of compilation into a TVM Function"""
pass
@tvm.register_node
class GraphFunc(tvm.node.NodeBase):
"""Compiled result of a graph into a TVM Function"""
pass
class Engine(object):
"""Global singleton compilation engine."""
def items(self):
"""List the available cache key value pairs.
Returns
-------
item_list : list of (GraphKey, GraphCacheEntry)
The existing cache items
"""
res = _list_cache_items()
assert len(res) % 2 == 0
return [(res[2*i], res[2*i+1]) for i in range(len(res)/2)]
def clear_cache(self):
"""Clear the existing cached functions."""
_clear_cache()
def __setitem__(self, key, value):
"""Clear the existing cached functions."""
if isinstance(value, GraphCacheEntry):
_set_cache_item(key, value.graph_func)
else:
_set_cache_item(key, value)
def __getitem__(self, key):
"""Clear the existing cached functions."""
return _get_cache_item(key)
def dump(self):
"""Return a string representation of engine dump
Returns
-------
dump : str
The dumped string representation
"""
items = self.items()
res = "====================================\n"
res += "CompilerEngine dump, %d items cached\n" % len(items)
for key, value in items:
res += "------------------------------------\n"
res += "target={}\n".format(key.target)
res += "inputs={}\n".format(key.inputs)
res += "use_count={}\n".format(value.use_count)
res += "func_name={}\n".format(value.graph_func.func_name)
res += key.graph.ir() + "\n"
res += "===================================\n"
return res
engine = Engine()
def graph_key(graph, inputs, target):
"""Construct a new graph key.
Parameters
----------
graph : Graph
The computation graph structure
inputs : list of Tensor(placeholder)
The input requirement to the graph.
target : str
The target of compilation.
"""
return _make_graph_key(graph, inputs, target)
"""Utilities for testcase"""
from .config import ctx_list
......@@ -2,7 +2,7 @@
import os
import tvm
def test_ctx_list():
def ctx_list():
"""Get context list for testcases"""
device_list = os.environ.get("NNVM_TEST_TARGETS", "")
device_list = (device_list.split(",") if device_list
......
/*!
* Copyright (c) 2017 by Contributors
* \file compile_engine.cc
* \brief The compile engine.
*/
#include <dmlc/common.h>
#include <tvm/ir.h>
#include <tvm/operation.h>
#include <nnvm/graph.h>
#include <nnvm/node.h>
#include <nnvm/pass_functions.h>
#include <nnvm/compiler/op_attr_types.h>
#include <mutex>
#include "./graph_hash.h"
#include "./compile_engine.h"
namespace nnvm {
namespace compiler {
using namespace tvm;
/*!
* \brief Get type flag from TVM Type
*
* \param type the tvm type.
* \return corresponding DLDataType
*/
int GetTypeFlag(tvm::Type type) {
if (type == tvm::Float(32)) return 0;
LOG(FATAL) << "cannot convert " << type;
return 0;
}
// convert from type flag to tvm type.
Type GetTVMType(int type_flag) {
if (type_flag == 0) return tvm::Float(32);
LOG(FATAL) << "unknown type_flag=" << type_flag;
return Float(32);
}
// internal compile engine
class CompileEngine {
public:
static CompileEngine* Global() {
static CompileEngine inst;
return &inst;
}
// lower graph possible get back an cached op.
GraphFunc Lower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr) {
GraphKey key = GraphKeyNode::make(graph, inputs, target);
std::lock_guard<std::mutex> lock(mutex_);
auto it = cache_.find(key);
if (it != cache_.end()) {
++(it->second->use_count);
return it->second->graph_func;
}
GraphFunc f = DoLower(key->graph, key->inputs, key->target,
schedule_op_key, schedule_op_attr);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
n->graph_func = f;
n->use_count = 1;
cache_[key] = GraphCacheEntry(n);
return f;
}
// List all items in the cache.
Array<NodeRef> ListCacheItems() {
std::lock_guard<std::mutex> lock(mutex_);
Array<NodeRef> items;
for (auto& kv : cache_) {
items.push_back(kv.first);
std::shared_ptr<GraphCacheEntryNode> n =
std::make_shared<GraphCacheEntryNode>(*(kv.second.operator->()));
items.push_back(GraphCacheEntry(n));
}
return items;
}
// Find the function given graph key.
GraphCacheEntry Find(const GraphKey& key) {
std::lock_guard<std::mutex> lock(mutex_);
auto it = cache_.find(key);
if (it != cache_.end()) {
return it->second;
} else {
return GraphCacheEntry();
}
}
// Find the function given graph key.
void Set(const GraphKey& key, GraphFunc func) {
std::lock_guard<std::mutex> lock(mutex_);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
n->graph_func = func;
n->use_count = 1;
cache_[key] = GraphCacheEntry(n);
}
// Find the function given graph key.
void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
cache_.clear();
}
// run the actual lowering process
GraphFunc DoLower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr) {
// shape, type
static auto& fcompute =
nnvm::Op::GetAttr<FTVMCompute>("FTVMCompute");
static auto& fschedule =
nnvm::Op::GetAttr<FTVMSchedule>("FTVMSchedule");
std::vector<TShape> ishape;
std::vector<int> idtype;
for (const tvm::Tensor t : inputs) {
std::vector<dim_t> shape;
for (Expr v : t->shape) {
CHECK(v.as<tvm::ir::IntImm>());
shape.push_back(v.as<tvm::ir::IntImm>()->value);
}
ishape.emplace_back(TShape(shape.begin(), shape.end()));
idtype.emplace_back(GetTypeFlag(t->dtype));
}
graph = pass::InferShape(graph, ishape);
graph = pass::InferType(graph, idtype);
const ShapeVector& shape_vec = graph.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = graph.GetAttr<DTypeVector>("dtype");
const IndexedGraph& idx = graph.indexed_graph();
CHECK_EQ(inputs.size(), idx.input_nodes().size());
std::vector<tvm::Tensor> tensor_vec(idx.num_node_entries());
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
uint32_t nid = idx.input_nodes()[i];
tensor_vec[idx.entry_id(nid, 0)] = inputs[i];
}
std::ostringstream readable_name;
readable_name << "fuse";
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
Array<Tensor> inputs, out_info;
readable_name << "_" << inode.source->op()->name;
// input array
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
CHECK(t.defined());
inputs.push_back(t);
}
// output hint
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
Array<Expr> shape;
for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
shape.push_back(make_const(Int(32), x));
}
out_info.push_back(
placeholder(shape,
GetTVMType(dtype_vec[idx.entry_id(nid, i)])));
}
// get default
Array<Tensor> out = fcompute[inode.source->op()](
inode.source->attrs, inputs, out_info);
CHECK_EQ(out.size(), inode.source->num_outputs());
// schedule on root node, and use master's schedule
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
tensor_vec[eid] = out[index];
}
}
// Schedule on final output.
Array<Tensor> outputs;
Array<Tensor> all_args = inputs;
for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
CHECK(t.defined());
outputs.push_back(t);
all_args.push_back(t);
}
Schedule sch = fschedule[schedule_op_key](
schedule_op_attr, outputs, target);
std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>();
gf->target = target;
gf->func_name = GetUniqeName(readable_name.str());
gf->inputs = inputs;
gf->outputs = outputs;
static const PackedFunc& flower = GetPackedFunc("nnvm.compiler.lower");
gf->funcs = flower(sch, all_args, gf->func_name);
return GraphFunc(gf);
}
private:
// Get unique name
std::string GetUniqeName(std::string name) {
while (true) {
auto it = name_map_.find(name);
if (it == name_map_.end()) {
name_map_[name] = 1;
return name;
} else {
std::ostringstream os;
os << name << "_" << it->second;
++(it->second);
name = os.str();
}
}
return name;
}
// global mutex
std::mutex mutex_;
// the name map
std::unordered_map<std::string, int> name_map_;
// the compiler cache
std::unordered_map<GraphKey, GraphCacheEntry,
GraphKeyHash, GraphKeyEqual> cache_;
};
GraphFunc GraphLower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr) {
return CompileEngine::Global()->Lower(
graph, inputs, target, schedule_op_key, schedule_op_attr);
}
// Expose cache to front end
TVM_REGISTER_GLOBAL("nnvm.compiler.ListCacheItems")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = CompileEngine::Global()->ListCacheItems();
});
TVM_REGISTER_GLOBAL("nnvm.compiler.ClearCache")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
CompileEngine::Global()->Clear();
});
// NOTE: this involves graph lookup and can be slow
TVM_REGISTER_GLOBAL("nnvm.compiler.GetCacheItem")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = CompileEngine::Global()->Find(args[0]);
});
TVM_REGISTER_GLOBAL("nnvm.compiler.SetCacheItem")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
CompileEngine::Global()->Set(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = args[0].operator GraphKey()->graph;
});
TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphKeyNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
p->stream << "GraphFunc(name=" << op->func_name
<< ", addr=" << op << ")";
});
} // namespace compiler
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file compile_engine.h
* \brief Internal engine to compile a subgraph fragment and cache compilation.
*/
#ifndef NNVM_COMPILER_COMPILE_ENGINE_H_
#define NNVM_COMPILER_COMPILE_ENGINE_H_
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include <string>
#include "./graph_hash.h"
namespace nnvm {
namespace compiler {
/*! \brief A TVM Node to represent compiled graph function */
struct GraphFuncNode : public tvm::Node {
/* \brief compiled target */
std::string target;
/*! \brief Function name */
std::string func_name;
/* \brief The inputs to the function */
tvm::Array<Tensor> inputs;
/* \brief The outputs to the function */
tvm::Array<Tensor> outputs;
/*! \brief The lowered functions */
tvm::Array<tvm::LoweredFunc> funcs;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("target", &target);
v->Visit("func_name", &func_name);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("funcs", &funcs);
}
static constexpr const char* _type_key = "GraphFunc";
TVM_DECLARE_NODE_TYPE_INFO(GraphFuncNode, tvm::Node);
};
TVM_DEFINE_NODE_REF(GraphFunc, GraphFuncNode);
/*! \brief Cache Entry in the graph */
struct GraphCacheEntryNode : public tvm::Node {
/*! \brief The graph function */
GraphFunc graph_func;
/*! \brief Usage statistics */
int use_count{0};
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("graph_func", &graph_func);
v->Visit("use_count", &use_count);
}
static constexpr const char* _type_key = "GraphCacheEntry";
TVM_DECLARE_NODE_TYPE_INFO(GraphCacheEntryNode, tvm::Node);
};
class GraphCacheEntry : public ::tvm::NodeRef {
public:
GraphCacheEntry() {}
explicit GraphCacheEntry(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}
GraphCacheEntryNode* operator->() {
return static_cast<GraphCacheEntryNode*>(node_.get());
}
using ContainerType = GraphCacheEntryNode;
};
/*!
* \brief Call compile engine to lower a graph with given inputs.
*
* \param graph The graph to be compiled
* \param inputs The input specification.
* \param schedule_op_key The hint key for the schedule.
* \param schedule_op_attr The hint attribute for the schedule.
*
* \return func A lowered tvm function.
*/
GraphFunc GraphLower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr);
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_COMPILE_ENGINE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file graph_deep_compare.cc
* \brief Deep compare two graph structure
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/runtime/packed_func.h>
#include "./node_attr.h"
namespace nnvm {
namespace compiler {
// deep compare the graph structure
// not considering the graph attributes
// return non-empty error message if the graph mismatch.
// the comparator won't match name of intermediate node.
// compare_var_attr
std::string DeepCompare(Graph a, Graph b,
bool compare_variable_attr) {
const IndexedGraph& idxa = a.indexed_graph();
const IndexedGraph& idxb = b.indexed_graph();
std::ostringstream err;
if (idxa.num_nodes() != idxb.num_nodes()) {
err << "Number of nodes mismatch";
return err.str();
}
if (idxa.num_node_entries() != idxb.num_node_entries()) {
err << "Number of node entry mismatch";
return err.str();
}
if (idxa.outputs().size() != idxb.outputs().size()) {
err << "Number of outputs mismatch";
return err.str();
}
for (size_t i = 0; i < idxa.outputs().size(); ++i) {
if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
idxa.outputs()[i].index != idxb.outputs()[i].index) {
err << "Output entry mismatch";
return err.str();
}
}
if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
err << "Number of inputs mismatch";
return err.str();
}
for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
const IndexedGraph::Node& anode = idxa[nid];
const IndexedGraph::Node& bnode = idxb[nid];
if (anode.source->op() != bnode.source->op()) {
err << "Node mismatch ";
return err.str();
}
if (anode.source->is_variable()) {
CHECK(bnode.source->is_variable());
if (!compare_variable_attr) continue;
}
AttrDict adict = GetAttrDict(anode.source->attrs);
AttrDict bdict = GetAttrDict(bnode.source->attrs);
auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
for (const auto& kv : adict) {
auto it = bdict.find(kv.first);
if (it != bdict.end()) {
if (it->second != kv.second) {
err << "Node attr mismatch, op=" << anode.source->attrs.name
<< " attr_key=" << kv.first << " " << it->second
<< " v.s. " << kv.second;
return false;
}
} else {
err << "One attr_key=" << kv.first << " is missing in another "
<< "op=" << anode.source->attrs.name;
return false;
}
}
return true;
};
if (!fmatch(adict, bdict)) return err.str();
if (adict.size() != bdict.size()) {
CHECK(!fmatch(bdict, adict));
return err.str();
}
if (anode.inputs.size() != bnode.inputs.size()) {
err << "Node input mismatch, op=" << anode.source->attrs.name;
return err.str();
}
if (anode.control_deps.size() != bnode.control_deps.size()) {
err << "Node control_deps mistach, op=" << anode.source->attrs.name;
return err.str();
}
for (size_t i = 0; i < anode.inputs.size(); ++i) {
const IndexedGraph::NodeEntry& ae = anode.inputs[i];
const IndexedGraph::NodeEntry& be = bnode.inputs[i];
if (ae.node_id != be.node_id ||
ae.index != be.index ||
ae.version != be.version) {
err << "Node input mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
for (size_t i = 0; i < anode.control_deps.size(); ++i) {
if (anode.control_deps[i] != bnode.control_deps[i]) {
err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
}
return "";
}
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = DeepCompare(args[0], args[1], args[2]);
});
} // namespace compiler
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file graph_deep_compare.cc
* \brief Deep compare two graph structure
*/
#include <dmlc/common.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <tvm/ir.h>
#include <tvm/runtime/packed_func.h>
#include <functional>
#include "./node_attr.h"
#include "./graph_hash.h"
namespace nnvm {
namespace compiler {
using namespace tvm;
using tvm::ir::IntImm;
size_t HashPlaceHolder(const Tensor& t) {
size_t key = t->shape.size();
key = dmlc::HashCombine(key, (t->dtype.code() << 8) | t->dtype.bits());
for (Expr s : t->shape) {
if (const IntImm* op = s.as<IntImm>()) {
key = dmlc::HashCombine(key, op->value);
}
}
return key;
}
bool PlaceHolderEqual(const Tensor& a, const Tensor& b) {
if (a->shape.size() != b->shape.size()) return false;
if (a->dtype != b->dtype) return false;
for (size_t i = 0; i < a->shape.size(); ++i) {
const IntImm* a_value = a->shape[i].as<IntImm>();
const IntImm* b_value = b->shape[i].as<IntImm>();
if (a_value && b_value == nullptr) return false;
if (b_value && a_value == nullptr) return false;
if (a_value == nullptr && b_value == nullptr) {
continue;
}
if (a_value->value != b_value->value) return false;
}
return true;
}
size_t GraphKeyHash::Hash(const GraphKey& gkey) {
if (gkey->cache_hash_key_ != 0) return gkey->cache_hash_key_;
size_t key = dmlc::HashCombine(GraphHash(gkey->graph), gkey->target);
key = dmlc::HashCombine(key, gkey->inputs.size());
for (size_t i = 0; i < gkey->inputs.size(); ++i) {
key = dmlc::HashCombine(key, HashPlaceHolder(gkey->inputs[i]));
}
if (key == 0) key = 1;
gkey->cache_hash_key_ = key;
return key;
}
bool GraphKeyEqual::Equal(const GraphKey& a,
const GraphKey& b) {
if (a->target != b->target) return false;
if (a->inputs.size() != b->inputs.size()) return false;
for (size_t i = 0; i < a->inputs.size(); ++i) {
if (!PlaceHolderEqual(a->inputs[i], b->inputs[i])) return false;
}
if (GraphDeepCompare(a->graph, b->graph, false).length() != 0) return false;
return true;
}
GraphKey GraphKeyNode::make(Graph graph,
tvm::Array<Tensor> inputs,
std::string target) {
std::shared_ptr<GraphKeyNode> n
= std::make_shared<GraphKeyNode>();
n->graph = std::move(graph);
n->inputs = inputs;
n->target = std::move(target);
return GraphKey(n);
}
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphKeyNode>([](const GraphKeyNode *op, IRPrinter *p) {
p->stream << "GraphKeyNode("<< op << ")";
});
// Run graph hash
size_t GraphHash(const Graph& graph) {
const IndexedGraph& idx = graph.indexed_graph();
size_t key = 0;
// Combine a linearized sequence of ops in subgraph
key = dmlc::HashCombine(key, idx.num_nodes());
std::hash<std::string> str_hash;
std::vector<size_t> hash_temp;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const IndexedGraph::Node& inode = idx[nid];
// Use name instad op address so it is deterministic across runs
key = dmlc::HashCombine(key, inode.source->op()p);
if (inode.source->is_variable()) continue;
hash_temp.clear();
for (const auto& kv : GetAttrDict(inode.source->attrs)) {
hash_temp.push_back(dmlc::HashCombine(str_hash(kv.first), kv.second));
}
// to make sure it is deterministic
// since unordered_map is not deterministic
std::sort(hash_temp.begin(), hash_temp.end());
for (size_t value : hash_temp) {
key = dmlc::HashCombine(key, value);
}
}
return key;
}
// deep compare the graph structure
// not considering the graph attributes
// return non-empty error message if the graph mismatch.
// the comparator won't match name of intermediate node.
// compare_var_attr
std::string GraphDeepCompare(const Graph& a,
const Graph& b,
bool compare_variable_attr) {
const IndexedGraph& idxa = a.indexed_graph();
const IndexedGraph& idxb = b.indexed_graph();
std::ostringstream err;
if (idxa.num_nodes() != idxb.num_nodes()) {
err << "Number of nodes mismatch";
return err.str();
}
if (idxa.num_node_entries() != idxb.num_node_entries()) {
err << "Number of node entry mismatch";
return err.str();
}
if (idxa.outputs().size() != idxb.outputs().size()) {
err << "Number of outputs mismatch";
return err.str();
}
for (size_t i = 0; i < idxa.outputs().size(); ++i) {
if (idxa.outputs()[i].node_id != idxb.outputs()[i].node_id ||
idxa.outputs()[i].index != idxb.outputs()[i].index) {
err << "Output entry mismatch";
return err.str();
}
}
if (idxa.input_nodes().size() != idxb.input_nodes().size()) {
err << "Number of inputs mismatch";
return err.str();
}
for (uint32_t nid = 0; nid < idxa.num_nodes(); ++nid) {
const IndexedGraph::Node& anode = idxa[nid];
const IndexedGraph::Node& bnode = idxb[nid];
if (anode.source->op() != bnode.source->op()) {
err << "Node mismatch ";
return err.str();
}
if (anode.source->is_variable()) {
CHECK(bnode.source->is_variable());
if (!compare_variable_attr) continue;
}
AttrDict adict = GetAttrDict(anode.source->attrs);
AttrDict bdict = GetAttrDict(bnode.source->attrs);
auto fmatch = [&err, &anode](const AttrDict& adict, const AttrDict& bdict) {
for (const auto& kv : adict) {
auto it = bdict.find(kv.first);
if (it != bdict.end()) {
if (it->second != kv.second) {
err << "Node attr mismatch, op=" << anode.source->attrs.name
<< " attr_key=" << kv.first << " " << it->second
<< " v.s. " << kv.second;
return false;
}
} else {
err << "One attr_key=" << kv.first << " is missing in another "
<< "op=" << anode.source->attrs.name;
return false;
}
}
return true;
};
if (!fmatch(adict, bdict)) return err.str();
if (adict.size() != bdict.size()) {
CHECK(!fmatch(bdict, adict));
return err.str();
}
if (anode.inputs.size() != bnode.inputs.size()) {
err << "Node input mismatch, op=" << anode.source->attrs.name;
return err.str();
}
if (anode.control_deps.size() != bnode.control_deps.size()) {
err << "Node control_deps mistach, op=" << anode.source->attrs.name;
return err.str();
}
for (size_t i = 0; i < anode.inputs.size(); ++i) {
const IndexedGraph::NodeEntry& ae = anode.inputs[i];
const IndexedGraph::NodeEntry& be = bnode.inputs[i];
if (ae.node_id != be.node_id ||
ae.index != be.index ||
ae.version != be.version) {
err << "Node input mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
for (size_t i = 0; i < anode.control_deps.size(); ++i) {
if (anode.control_deps[i] != bnode.control_deps[i]) {
err << "Node control_dep mismatch on, op=" << anode.source->attrs.name;
return err.str();
}
}
}
return "";
}
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphDeepCompare(args[0], args[1], args[2]);
});
} // namespace compiler
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file graph_hash.h
* \brief The graph hashing function.
*/
#ifndef NNVM_COMPILER_GRAPH_HASH_H_
#define NNVM_COMPILER_GRAPH_HASH_H_
#include <dmlc/common.h>
#include <nnvm/graph.h>
#include <tvm/operation.h>
#include <string>
namespace nnvm {
namespace compiler {
class GraphKey;
/*! \brief Key to a graph compiler cache */
struct GraphKeyNode : public tvm::Node {
/*! \brief The graph structure */
Graph graph;
/* \brief The inputs to the function */
tvm::Array<Tensor> inputs;
/*! \brief The target */
std::string target;
// Cached internal hash key, invisible to the user.
// The graph hash key is ensured always not to be 0
mutable size_t cache_hash_key_{0};
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("inputs", &inputs);
v->Visit("target", &target);
}
static GraphKey make(Graph graph,
tvm::Array<Tensor> inputs,
std::string target);
static constexpr const char* _type_key = "GraphKey";
TVM_DECLARE_NODE_TYPE_INFO(GraphKeyNode, tvm::Node);
};
TVM_DEFINE_NODE_REF(GraphKey, GraphKeyNode);
/*! \brief Hashing function for graph key */
struct GraphKeyHash {
size_t operator()(const GraphKey& gkey) const {
return Hash(gkey);
}
static size_t Hash(const GraphKey& gkey);
};
/*! \brief function for graph key */
struct GraphKeyEqual {
bool operator()(const GraphKey& a,
const GraphKey& b) const {
return Equal(a, b);
}
static bool Equal(const GraphKey& a, const GraphKey& b);
};
/*!
* \brief Create a hash code for a given graph.
* \return The hash code of the graph.
*/
size_t GraphHash(const Graph& graph);
/*!
* \brief Compare two graphs
* return empty string if they are equal
* otherwise return error message
* \param a The first graph.
* \param b The second graph.
* \return empty string if they are equal, otherwise return error message.
*/
std::string GraphDeepCompare(const Graph& a,
const Graph& b,
bool compare_variable_attr);
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_GRAPH_HASH_H_
......@@ -9,7 +9,7 @@
namespace nnvm {
const IndexedGraph& Graph::indexed_graph() {
const IndexedGraph& Graph::indexed_graph() const {
if (indexed_graph_ == nullptr) {
indexed_graph_.reset(new IndexedGraph(*this));
}
......
......@@ -180,7 +180,7 @@ void PrintGraphIR_(Graph src,
}
// save a graph to json
Graph PrintGraphIR(Graph src) {
Graph PrintGraphIRPass(Graph src) {
std::ostringstream os;
std::vector<std::string> join_entry_attrs, join_node_attrs;
if (src.attrs.count("join_entry_attrs") != 0) {
......@@ -200,7 +200,7 @@ Graph PrintGraphIR(Graph src) {
// register pass
NNVM_REGISTER_PASS(PrintGraphIR)
.describe("Return a empty Graph, save ir to ret.attrs[\"graphir\"]")
.set_body(PrintGraphIR);
.set_body(PrintGraphIRPass);
} // namespace pass
} // namespace nnvm
import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
def test_compile_cache():
x = sym.Variable("x")
y = sym.Variable("y")
z = sym.exp(y + x)
shape = (10, 1)
dtype = tvm.float32
shape_dict = {"x": shape, "y": shape}
def verify(graph, lib):
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
# get member functions
na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
nb = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
m.run(x=na, y=nb)
# get outputs
out = m.get_output(0, tvm.nd.empty(shape, dtype))
np.testing.assert_allclose(
out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
engine = nnvm.compiler.engine
graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
inputs = [tvm.placeholder((10,)), tvm.placeholder((10,))]
gkey = nnvm.compiler.graph_key(nnvm.graph.create(z), inputs, "llvm")
gkey2 = nnvm.compiler.graph_key(nnvm.graph.create(z), inputs + inputs, "llvm")
gf = engine[gkey]
assert gf is not None
assert engine[gkey2] is None
graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
assert graph.index.num_nodes == 3
verify(graph, lib)
# Test various set external cache
engine.clear_cache()
engine[gkey] = gf
if __name__ == "__main__":
test_compile_cache()
......@@ -4,7 +4,7 @@ import tvm
import topi
from nnvm import symbol as sym
from nnvm.compiler import graph_util, graph_attr
from nnvm.testing.config import test_ctx_list
from nnvm.testing import ctx_list
def test_ewise_injective():
x = sym.Variable("x")
......@@ -14,7 +14,7 @@ def test_ewise_injective():
shape_dict = {"x": dshape}
dtype = "float32"
target = "llvm"
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
assert graph.index.num_nodes == 2
m = nnvm.runtime.create(graph, lib, ctx)
......@@ -37,7 +37,7 @@ def test_conv_ewise_injective():
oshape = (1, 32* 18 * 18)
shape_dict = {"x": dshape}
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = nnvm.runtime.create(graph, lib, ctx)
# print(graph.ir(join_entry_attrs=["shape"]))
......@@ -64,7 +64,7 @@ def test_injective_reduce_injective():
dshape = (32, 1, 18, 18)
shape_dict = {"x": dshape}
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = nnvm.runtime.create(graph, lib, ctx)
assert graph.index.num_nodes == 2
......
......@@ -4,7 +4,7 @@ import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.testing.config import test_ctx_list
from nnvm.testing.config import ctx_list
def test_relu():
x = sym.Variable("x")
......@@ -13,7 +13,7 @@ def test_relu():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -31,7 +31,7 @@ def test_exp():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -54,7 +54,7 @@ def test_log():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
......@@ -78,7 +78,7 @@ def test_tanh():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
......@@ -102,7 +102,7 @@ def test_sigmoid():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -125,7 +125,7 @@ def test_softmax():
dtype = "float32"
dshape = (10, 1000)
oshape = dshape
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
......@@ -153,7 +153,7 @@ def test_dense():
"dense_weight" : (3, 100),
"dense_bias" : (3,),
}
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape)
m = nnvm.runtime.create(graph, lib, ctx)
x_np = np.random.uniform(size=shape["x"]).astype(dtype)
......@@ -179,7 +179,7 @@ def test_batchnorm():
y = sym.batch_norm(
x, gamma, beta, moving_mean, moving_var, epsilon=eps)
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": shape})
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
x_np = np.random.uniform(size=shape).astype(dtype)
......
......@@ -5,7 +5,7 @@ import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.testing.config import test_ctx_list
from nnvm.testing.config import ctx_list
def test_conv2d():
......@@ -17,7 +17,7 @@ def test_conv2d():
kshape = (10, 3, 3, 3)
oshape = (1, 10, 18, 18)
shape_dict = {"x": dshape}
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = nnvm.runtime.create(graph, lib, ctx)
# get member functions
......@@ -46,7 +46,7 @@ def test_grouped_conv2d():
kshape = (32, 1, 3, 3)
oshape = (1, 32, 18, 18)
shape_dict = {"x": dshape}
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
m = nnvm.runtime.create(graph, lib, ctx)
# set input
......
......@@ -4,7 +4,7 @@ import topi
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.testing.config import test_ctx_list
from nnvm.testing.config import ctx_list
def verify_transpose(dshape, axes):
x = sym.Variable("x")
......@@ -14,7 +14,7 @@ def verify_transpose(dshape, axes):
y = sym.transpose(x)
y = y + 1
dtype = "float32"
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# set input
......@@ -29,7 +29,7 @@ def verify_reduce(dshape, fnp, fsym, **kwargs):
x = sym.Variable("x")
y = fsym(x + 1, **kwargs)
dtype = "float32"
for target, ctx in test_ctx_list():
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = nnvm.runtime.create(graph, lib, ctx)
# set input
......@@ -54,3 +54,4 @@ def test_reduce():
if __name__ == "__main__":
test_reduce()
test_tranpose()
print(nnvm.compiler.engine.dump())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment