Commit 6c198621 authored by Lianmin Zheng, committed by Tianqi Chen

Add CacheItem2ScheduleArgs Extension (#338)

* add CacheItem2ScheduleArgs extension

* fix lint

* move function position

* make cache item visible to frontend
parent 2bb4a1e7
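
In short: this commit replaces the (schedule_op_key, schedule_op_attr) pair threaded through the compile engine with a single master_idx, the index of the master node whose FTVMSchedule is invoked. The master index is recorded on each GraphCacheEntryNode, and a new packed function, nnvm.compiler.CacheItem2ScheduleArgs, lets the frontend recover the schedule and its arguments from a cached item. The lowering entry point after this change (see compile_engine.h below):

    GraphFunc GraphLower(Graph graph,
                         const Array<tvm::Tensor>& inputs,
                         const std::string& target,
                         int master_idx);  // was: const Op* schedule_op_key, const NodeAttrs& schedule_op_attr
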
@@ -82,8 +82,7 @@ class CompileEngine {
   GraphFunc Lower(Graph graph,
                   const Array<tvm::Tensor>& inputs,
                   const std::string& target,
-                  const Op* schedule_op_key,
-                  const NodeAttrs& schedule_op_attr) {
+                  int master_idx) {
     GraphKey key = GraphKeyNode::make(graph, inputs, target);
     std::lock_guard<std::mutex> lock(mutex_);
     auto it = cache_.find(key);
@@ -91,11 +90,11 @@ class CompileEngine {
       ++(it->second->use_count);
       return it->second->graph_func;
     }
-    GraphFunc f = DoLower(key->graph, key->inputs, key->target,
-                          schedule_op_key, schedule_op_attr);
+    GraphFunc f = DoLower(key->graph, key->inputs, key->target, master_idx);
     std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
     n->graph_func = f;
     n->use_count = 1;
+    n->master_idx = master_idx;
     cache_[key] = GraphCacheEntry(n);
     return f;
   }
@@ -134,12 +133,14 @@ class CompileEngine {
     std::lock_guard<std::mutex> lock(mutex_);
     cache_.clear();
   }
-  // run the actual lowering process
-  GraphFunc DoLower(Graph graph,
-                    const Array<tvm::Tensor>& inputs,
-                    const std::string& target,
-                    const Op* schedule_op_key,
-                    const NodeAttrs& schedule_op_attr) {
+
+  // get schedule and its args
+  std::pair<Schedule, Array<tvm::Tensor> > GetScheduleArgs(Graph graph,
+          const Array<tvm::Tensor> &inputs,
+          const std::string &target,
+          int master_idx,
+          std::string *readable_name,
+          Array<tvm::Tensor> *outputs) {
     // shape, type
     static auto& fcompute =
         nnvm::Op::GetAttr<FTVMCompute>("FTVMCompute");
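
GetScheduleArgs factors the schedule construction out of DoLower so it can also serve the new packed function: it returns the (Schedule, all_args) pair directly and fills the two out-parameters only when they are non-null. A minimal call sketch under that contract:

    // Sketch: callers that don't need the fused name or the output tensors
    // pass nullptr, exactly as nnvm.compiler.CacheItem2ScheduleArgs does below.
    Schedule sch;
    Array<tvm::Tensor> all_args;
    std::tie(sch, all_args) = CompileEngine::Global()->GetScheduleArgs(
        graph, inputs, target, master_idx,
        /*readable_name=*/nullptr, /*outputs=*/nullptr);
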
@@ -172,18 +173,18 @@ class CompileEngine {
       tensor_vec[idx.entry_id(nid, 0)] = inputs[i];
     }
-    std::ostringstream readable_name;
-    readable_name << "fuse";
+    std::ostringstream readable_name_os;
+    readable_name_os << "fuse";
     for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
       const auto& inode = idx[nid];
       if (inode.source->is_variable()) continue;
-      Array<Tensor> inputs, out_info;
-      readable_name << "_" << inode.source->op()->name;
+      Array<Tensor> op_inputs, out_info;
+      readable_name_os << "_" << inode.source->op()->name;
       // input array
       for (const IndexedGraph::NodeEntry& e : inode.inputs) {
         const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
         CHECK(t.defined());
-        inputs.push_back(t);
+        op_inputs.push_back(t);
       }
       // output hint
       for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
@@ -198,7 +199,7 @@ class CompileEngine {
       }
       // get default
       Array<Tensor> out = fcompute[inode.source->op()](
-          inode.source->attrs, inputs, out_info);
+          inode.source->attrs, op_inputs, out_info);
       CHECK_EQ(out.size(), inode.source->num_outputs());
       // schedule on root node, and use master's schedule
       for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
@@ -207,19 +208,43 @@ class CompileEngine {
       }
     }
     // Schedule on final output.
-    Array<Tensor> outputs;
     Array<Tensor> all_args = inputs;
+    Array<Tensor> outs;
     for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
       const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
       CHECK(t.defined());
-      outputs.push_back(t);
+      outs.push_back(t);
       all_args.push_back(t);
     }
-    Schedule sch = fschedule[schedule_op_key](
-        schedule_op_attr, outputs, target);
+
+    Schedule sch = fschedule[idx[master_idx].source->op()](
+        idx[master_idx].source->attrs, outs, target);
+
+    // store extra return values
+    if (readable_name != nullptr)
+      *readable_name = readable_name_os.str();
+    if (outputs != nullptr)
+      *outputs = outs;
+
+    return std::make_pair(sch, all_args);
+  }
+
+  // run the actual lowering process
+  GraphFunc DoLower(Graph graph,
+                    const Array<tvm::Tensor>& inputs,
+                    const std::string& target,
+                    int master_idx) {
+    std::string readable_name;
+    Array<tvm::Tensor> all_args;
+    Array<tvm::Tensor> outputs;
+    Schedule sch;
+
+    std::tie(sch, all_args) = GetScheduleArgs(graph, inputs, target, master_idx,
+                                              &readable_name, &outputs);
+
     std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>();
     gf->target = target;
-    gf->func_name = GetUniqeName(readable_name.str());
+    gf->func_name = GetUniqeName(readable_name);
     gf->inputs = inputs;
     gf->outputs = outputs;
     static const PackedFunc& flower = GetPackedFunc("nnvm.compiler.lower");
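
A side note on the pair-return pattern introduced above: std::pair comes from <utility> (which this commit adds to the header below), while std::tie lives in <tuple> and is presumably already available transitively here. The pattern in isolation, for reference:

    #include <tuple>    // std::tie
    #include <utility>  // std::pair, std::make_pair

    std::pair<int, double> get() { return std::make_pair(42, 2.5); }

    void use() {
      int a;
      double b;
      std::tie(a, b) = get();  // unpacks the pair into existing locals
    }
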
@@ -257,10 +282,9 @@ class CompileEngine {
 GraphFunc GraphLower(Graph graph,
                      const Array<tvm::Tensor>& inputs,
                      const std::string& target,
-                     const Op* schedule_op_key,
-                     const NodeAttrs& schedule_op_attr) {
+                     int master_idx) {
   return CompileEngine::Global()->Lower(
-      graph, inputs, target, schedule_op_key, schedule_op_attr);
+      graph, inputs, target, master_idx);
 }

 // Expose cache to front end
@@ -295,6 +319,30 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
     *rv = GraphKeyNode::make(args[0], args[1], args[2]);
   });

+// This can be used to extract workloads from nnvm compiler
+TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    Array<tvm::NodeRef> item = args[0];
+
+    const GraphKeyNode *key = reinterpret_cast<const GraphKeyNode *>(item[0].get());
+    const GraphCacheEntryNode *value = reinterpret_cast<const GraphCacheEntryNode *>(item[1].get());
+
+    // extract arguments from cached item
+    Graph graph = key->graph;
+    const Array<tvm::Tensor> &inputs = key->inputs;
+    std::string target = args[1];
+    int master_idx = value->master_idx;
+
+    Schedule sch;
+    Array<tvm::Tensor> all_args;
+    std::tie(sch, all_args) = CompileEngine::Global()->GetScheduleArgs(
+        graph, inputs, target, master_idx, nullptr, nullptr);
+
+    Array<tvm::NodeRef> ret;
+    ret.push_back(sch);
+    ret.push_back(all_args);
+    *rv = ret;
+  });
+
 TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 .set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
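
A hedged usage sketch (not part of this diff) of the new global from C++; `item` is assumed to be an Array<tvm::NodeRef>{graph_key, cache_entry} taken from the cache the engine exposes to the frontend, and the return value is the two-element array packed above:

    #include <tvm/runtime/registry.h>

    // Sketch only: the registry lookup and call protocol are standard TVM;
    // `item` and the NodeRef return-value conversions (via packed_func_ext)
    // are assumed to be in scope.
    const tvm::runtime::PackedFunc* f =
        tvm::runtime::Registry::Get("nnvm.compiler.CacheItem2ScheduleArgs");
    CHECK(f != nullptr);
    tvm::runtime::TVMRetValue rv = (*f)(item, "llvm");
    Array<tvm::NodeRef> ret = rv;  // ret[0]: Schedule, ret[1]: Array<Tensor> all_args
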
......
@@ -17,6 +17,7 @@
 #include <tvm/operation.h>
 #include <tvm/lowered_func.h>
 #include <string>
+#include <utility>
 #include "./graph_hash.h"

 namespace nnvm {
@@ -55,10 +56,13 @@ struct GraphCacheEntryNode : public tvm::Node {
   GraphFunc graph_func;
   /*! \brief Usage statistics */
   int use_count{0};
+  /*! \brief Index of the master node for calling schedule */
+  int master_idx;
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("graph_func", &graph_func);
     v->Visit("use_count", &use_count);
+    v->Visit("master_idx", &master_idx);
   }
   static constexpr const char* _type_key = "GraphCacheEntry";
   TVM_DECLARE_NODE_TYPE_INFO(GraphCacheEntryNode, tvm::Node);
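
Review note: unlike use_count{0}, master_idx gets no in-class initializer, so an entry constructed anywhere other than CompileEngine::Lower would expose an indeterminate value through VisitAttrs. A defensive default would match the existing style (suggestion only, not part of this diff):

    /*! \brief Index of the master node for calling schedule */
    int master_idx{0};  // suggested: default-initialize like use_count
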
@@ -79,16 +83,15 @@ class GraphCacheEntry : public ::tvm::NodeRef {
  *
  * \param graph The graph to be compiled
  * \param inputs The input specification.
- * \param schedule_op_key The hint key for the schedule.
- * \param schedule_op_attr The hint attribute for the schedule.
+ * \param target The build target.
+ * \param master_idx The index of the master node for calling schedule.
  *
  * \return func A lowered tvm function.
  */
 GraphFunc GraphLower(Graph graph,
                      const Array<tvm::Tensor>& inputs,
                      const std::string& target,
-                     const Op* schedule_op_key,
-                     const NodeAttrs& schedule_op_attr);
+                     int master_idx);

 /*!
  * \brief Get type flag from TVM Type
......
@@ -315,9 +315,15 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
         auto it = fe.input_info.find(subidx[sub_input_id].source);
         inputs.push_back(it->second);
       }
-      fe.compiled_func = GraphLower(fe.subgraph, inputs, target,
-                                    idx[master].source->op(),
-                                    idx[master].source->attrs);
+      // find master idx in subgraph
+      int sub_master_idx = 0;
+      for (uint32_t i = 0; i < subidx.num_nodes(); i++) {
+        if (subidx[i].source->op() == idx[master].source->op()) {
+          sub_master_idx = i;
+          break;
+        }
+      }
+      fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx);
       for (LoweredFunc f : fe.compiled_func->funcs) {
         if (!func_set.count(f.get())) {
           func_set.insert(f.get());
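
Review note: the search maps the master back into the subgraph by op identity and takes the first hit, which is fine while a fused group contains at most one node of the master's op; if two nodes shared the op (say, two conv2d in one group), comparing attributes as well would disambiguate. A sketch of the stricter match (suggestion only, not part of this diff):

    // Sketch: also require matching string attributes, not just the same Op*.
    if (subidx[i].source->op() == idx[master].source->op() &&
        subidx[i].source->attrs.dict == idx[master].source->attrs.dict) {
      sub_master_idx = i;
      break;
    }
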
......
@@ -9,6 +9,7 @@
 #include <nnvm/compiler/packed_func_ext.h>
 #include <nnvm/compiler/op_attr_types.h>
 #include "./node_attr.h"
+#include "compile_engine.h"

 namespace tvm {
 namespace runtime {
......