Commit 6c198621 by Lianmin Zheng Committed by Tianqi Chen

Add CacheItem2Schedule Extension (#338)

* add CacheItem2Schedule extension

* fix lint

* move function position

* make cache item visible to frontend
parent 2bb4a1e7
......@@ -82,8 +82,7 @@ class CompileEngine {
GraphFunc Lower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr) {
int master_idx) {
GraphKey key = GraphKeyNode::make(graph, inputs, target);
std::lock_guard<std::mutex> lock(mutex_);
auto it = cache_.find(key);
......@@ -91,11 +90,11 @@ class CompileEngine {
++(it->second->use_count);
return it->second->graph_func;
}
GraphFunc f = DoLower(key->graph, key->inputs, key->target,
schedule_op_key, schedule_op_attr);
GraphFunc f = DoLower(key->graph, key->inputs, key->target, master_idx);
std::shared_ptr<GraphCacheEntryNode> n = std::make_shared<GraphCacheEntryNode>();
n->graph_func = f;
n->use_count = 1;
n->master_idx = master_idx;
cache_[key] = GraphCacheEntry(n);
return f;
}
......@@ -134,12 +133,14 @@ class CompileEngine {
std::lock_guard<std::mutex> lock(mutex_);
cache_.clear();
}
// run the actual lowering process
GraphFunc DoLower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr) {
// get schedule and its args
std::pair<Schedule, Array<tvm::Tensor> > GetScheduleArgs(Graph graph,
const Array<tvm::Tensor> &inputs,
const std::string &target,
int master_idx,
std::string *readable_name,
Array<tvm::Tensor> *outputs) {
// shape, type
static auto& fcompute =
nnvm::Op::GetAttr<FTVMCompute>("FTVMCompute");
......@@ -172,18 +173,18 @@ class CompileEngine {
tensor_vec[idx.entry_id(nid, 0)] = inputs[i];
}
std::ostringstream readable_name;
readable_name << "fuse";
std::ostringstream readable_name_os;
readable_name_os << "fuse";
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
Array<Tensor> inputs, out_info;
readable_name << "_" << inode.source->op()->name;
Array<Tensor> op_inputs, out_info;
readable_name_os << "_" << inode.source->op()->name;
// input array
for (const IndexedGraph::NodeEntry& e : inode.inputs) {
const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
CHECK(t.defined());
inputs.push_back(t);
op_inputs.push_back(t);
}
// output hint
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
......@@ -198,7 +199,7 @@ class CompileEngine {
}
// get default
Array<Tensor> out = fcompute[inode.source->op()](
inode.source->attrs, inputs, out_info);
inode.source->attrs, op_inputs, out_info);
CHECK_EQ(out.size(), inode.source->num_outputs());
// schedule on root node, and use master's schedule
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
......@@ -207,19 +208,43 @@ class CompileEngine {
}
}
// Schedule on final output.
Array<Tensor> outputs;
Array<Tensor> all_args = inputs;
Array<Tensor> outs;
for (const IndexedGraph::NodeEntry& e : idx.outputs()) {
const tvm::Tensor& t = tensor_vec[idx.entry_id(e)];
CHECK(t.defined());
outputs.push_back(t);
outs.push_back(t);
all_args.push_back(t);
}
Schedule sch = fschedule[schedule_op_key](
schedule_op_attr, outputs, target);
Schedule sch = fschedule[idx[master_idx].source->op()](
idx[master_idx].source->attrs, outs, target);
// store extra return values
if (readable_name != nullptr)
*readable_name = readable_name_os.str();
if (outputs != nullptr)
*outputs = outs;
return std::make_pair(sch, all_args);
}
// run the actual lowering process
GraphFunc DoLower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
int master_idx) {
std::string readable_name;
Array<tvm::Tensor> all_args;
Array<tvm::Tensor> outputs;
Schedule sch;
std::tie(sch, all_args) = GetScheduleArgs(graph, inputs, target, master_idx,
&readable_name, &outputs);
std::shared_ptr<GraphFuncNode> gf = std::make_shared<GraphFuncNode>();
gf->target = target;
gf->func_name = GetUniqeName(readable_name.str());
gf->func_name = GetUniqeName(readable_name);
gf->inputs = inputs;
gf->outputs = outputs;
static const PackedFunc& flower = GetPackedFunc("nnvm.compiler.lower");
......@@ -257,10 +282,9 @@ class CompileEngine {
// Free-function entry point for lowering a graph: it forwards its arguments
// unchanged to Lower() on the engine returned by CompileEngine::Global() and
// returns the resulting GraphFunc.
//
// NOTE(review): this listing interleaves the pre-change signature
// (schedule_op_key / schedule_op_attr) with the post-change one (master_idx);
// only one of the two parameter tails and one of the two argument lists
// belongs in the real source file.
GraphFunc GraphLower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr) {
int master_idx) {
// Delegate to the engine's Lower().
return CompileEngine::Global()->Lower(
graph, inputs, target, schedule_op_key, schedule_op_attr);
graph, inputs, target, master_idx);
}
// Expose cache to front end
......@@ -295,6 +319,30 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
*rv = GraphKeyNode::make(args[0], args[1], args[2]);
});
// This can be used to extract workloads from nnvm compiler
TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    // args[0]: a cache item given as a two-element array {key, entry}.
    // args[1]: the build target string.
    // Result: a two-element array {schedule, all_args} as produced by
    // CompileEngine::GetScheduleArgs.
    Array<tvm::NodeRef> item = args[0];
    CHECK_EQ(item.size(), 2U)
        << "CacheItem2ScheduleArgs expects a cache item of the form [key, entry]";

    // The item comes from the front end, so downcast with the type-checked
    // as<T>() instead of reinterpret_cast and reject anything that is not the
    // expected node type rather than reading through a mis-typed pointer.
    const GraphKeyNode *key = item[0].as<GraphKeyNode>();
    const GraphCacheEntryNode *value = item[1].as<GraphCacheEntryNode>();
    CHECK(key != nullptr)
        << "CacheItem2ScheduleArgs: item[0] is not a GraphKey";
    CHECK(value != nullptr)
        << "CacheItem2ScheduleArgs: item[1] is not a GraphCacheEntry";

    // extract arguments from cached item
    Graph graph = key->graph;
    const Array<tvm::Tensor> &inputs = key->inputs;
    std::string target = args[1];
    int master_idx = value->master_idx;

    // The readable name and the separate output list are not needed here,
    // so both optional out-parameters are passed as nullptr.
    Schedule sch;
    Array<tvm::Tensor> all_args;
    std::tie(sch, all_args) = CompileEngine::Global()->GetScheduleArgs(
        graph, inputs, target, master_idx, nullptr, nullptr);

    Array<tvm::NodeRef> ret;
    ret.push_back(sch);
    ret.push_back(all_args);
    *rv = ret;
  });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GraphFuncNode>([](const GraphFuncNode *op, IRPrinter *p) {
......
......@@ -17,6 +17,7 @@
#include <tvm/operation.h>
#include <tvm/lowered_func.h>
#include <string>
#include <utility>
#include "./graph_hash.h"
namespace nnvm {
......@@ -55,10 +56,13 @@ struct GraphCacheEntryNode : public tvm::Node {
GraphFunc graph_func;
/*! \brief Usage statistics */
int use_count{0};
/*! \brief Index of the master node for calling schedule*/
int master_idx;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("graph_func", &graph_func);
v->Visit("use_count", &use_count);
v->Visit("master_idx", &master_idx);
}
static constexpr const char* _type_key = "GraphCacheEntry";
TVM_DECLARE_NODE_TYPE_INFO(GraphCacheEntryNode, tvm::Node);
......@@ -79,16 +83,15 @@ class GraphCacheEntry : public ::tvm::NodeRef {
*
* \param graph The graph to be compiled
* \param inputs The input specification.
* \param schedule_op_key The hint key for the schedule.
* \param schedule_op_attr The hint attribute for the schedule.
* \param target The build target
* \param master_idx The index of master node for calling schedule
*
* \return func A lowered tvm function.
*/
GraphFunc GraphLower(Graph graph,
const Array<tvm::Tensor>& inputs,
const std::string& target,
const Op* schedule_op_key,
const NodeAttrs& schedule_op_attr);
int master_idx);
/*!
* \brief Get type flag from TVM Type
......
......@@ -315,9 +315,15 @@ nnvm::Graph GraphFuseCompile(nnvm::Graph g) {
auto it = fe.input_info.find(subidx[sub_input_id].source);
inputs.push_back(it->second);
}
fe.compiled_func = GraphLower(fe.subgraph, inputs, target,
idx[master].source->op(),
idx[master].source->attrs);
// find master idx in subgraph
int sub_master_idx = 0;
for (uint32_t i = 0; i < subidx.num_nodes(); i++) {
if (subidx[i].source->op() == idx[master].source->op()) {
sub_master_idx = i;
break;
}
}
fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx);
for (LoweredFunc f : fe.compiled_func->funcs) {
if (!func_set.count(f.get())) {
func_set.insert(f.get());
......
......@@ -9,6 +9,7 @@
#include <nnvm/compiler/packed_func_ext.h>
#include <nnvm/compiler/op_attr_types.h>
#include "./node_attr.h"
#include "compile_engine.h"
namespace tvm {
namespace runtime {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment