Unverified Commit c113712d by Tianqi Chen Committed by GitHub

[RELAY][BACKEND] Enable PlanMemory in the graph runtime. (#2120)

parent 6edb3564
...@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const { ...@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const {
/*! /*!
* \brief Print node as text format. * \brief Print node as text format.
* \param node The node to be printed. * \param node The node to be printed.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching * \param annotate An optional callback function for attaching
* additional comment block to an expr. * additional comment block to an expr.
* \return The text representation. * \return The text representation.
*/ */
std::string RelayPrint( std::string RelayPrint(
const NodeRef& node, const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr); runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None): ...@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None):
funcs : List[tvm.LoweredFunc] funcs : List[tvm.LoweredFunc]
The list of lowered functions. The list of lowered functions.
target : tvm.Target target : tvm.Target
The target to run the code on. The target to run the code on.
......
...@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime compatible system. ...@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime compatible system.
from __future__ import absolute_import from __future__ import absolute_import
import json import json
import attr import attr
from . import _backend
from . import compile_engine from . import compile_engine
from ..op import Op from ..op import Op
from ..expr import Function, GlobalVar, ExprFunctor from ..expr import Function, GlobalVar, ExprFunctor
...@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor):
self.nodes = [] self.nodes = []
self.var_map = {} self.var_map = {}
self.params = {} self.params = {}
self.storage_map = None
self.compile_engine = compile_engine.get() self.compile_engine = compile_engine.get()
self.lowered_funcs = set() self.lowered_funcs = set()
self._name_map = {} self._name_map = {}
def add_node(self, node, checked_type): def add_node(self, node, expr):
""" """
Add a node to the graph. Add a node to the graph.
...@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor):
node: Node node: Node
The node to add to the graph. The node to add to the graph.
checked_type: Type expr: tvm.relay.Expr
The type of the node. The corresponding expression.
Returns Returns
------- -------
node_ref: Union[NodeRef, List[NodeRef]] node_ref: Union[NodeRef, List[NodeRef]]
A reference to the node. A reference to the node.
""" """
checked_type = expr.checked_type
# setup storage ids
assert expr in self.storage_map
node.attrs["storage_id"] = [
x.value for x in self.storage_map[expr]
]
node_id = len(self.nodes) node_id = len(self.nodes)
self.nodes.append(node) self.nodes.append(node)
# Tuple return value, flatten as tuple # Tuple return value, flatten as tuple
...@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor):
name = "p%d" % index name = "p%d" % index
self.params[name] = op.data self.params[name] = op.data
node = InputNode(name, {}) node = InputNode(name, {})
return self.add_node(node, op.checked_type) return self.add_node(node, op)
def visit_function(self, _): def visit_function(self, _):
raise RuntimeError("function not supported") raise RuntimeError("function not supported")
...@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor):
op_name = cached_func.func_name op_name = cached_func.func_name
op_node = OpNode(self._get_unique_name(op_name), {}, op_node = OpNode(self._get_unique_name(op_name), {},
op_name, inputs, {}) op_name, inputs, {})
return self.add_node(op_node, call.checked_type) return self.add_node(op_node, call)
def _get_json(self): def _get_json(self):
""" """
...@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor):
assert node.num_outputs == len(node.attrs["shape"]) assert node.num_outputs == len(node.attrs["shape"])
shapes += node.attrs["shape"] shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"] dltypes += node.attrs["dtype"]
for i in range(node.num_outputs): storage_ids += node.attrs["storage_id"]
storage_ids.append(i + num_entry)
num_entry += node.num_outputs num_entry += node.num_outputs
node_row_ptr.append(num_entry) node_row_ptr.append(num_entry)
...@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor):
return json.dumps(json_dict, indent=2) return json.dumps(json_dict, indent=2)
def debug_dump_memory_plan(self, func):
    """Render ``func`` as text with the planned storage attached to each expression.

    Parameters
    ----------
    func
        Object whose ``astext`` method is called to produce the dump.

    Returns
    -------
    The result of ``func.astext`` with the meta data section disabled and
    an annotation callback that looks expressions up in ``self.storage_map``.
    """
    def _annotate(expr):
        # Expressions without a storage_map entry get an empty annotation.
        return str(self.storage_map[expr]) if expr in self.storage_map else ""

    dump = func.astext(show_meta_data=False, annotate=_annotate)
    return dump
def codegen(self, func): def codegen(self, func):
"""Compile a single function into a graph. """Compile a single function into a graph.
...@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor):
params : Dict[str, tvm.nd.NDArray] params : Dict[str, tvm.nd.NDArray]
Additional constant parameters. Additional constant parameters.
""" """
self.storage_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes. # First we convert all the parameters into input nodes.
for param in func.params: for param in func.params:
node = InputNode(param.name_hint, {}) node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node( self.var_map[param] = self.add_node(
node, param.type_annotation) node, param)
# Then we compile the body into a graph which can depend # Then we compile the body into a graph which can depend
# on input variables. # on input variables.
......
...@@ -23,7 +23,7 @@ def register_relay_node(type_key=None): ...@@ -23,7 +23,7 @@ def register_relay_node(type_key=None):
class RelayNode(NodeBase): class RelayNode(NodeBase):
"""Base class of all relay node.""" """Base class of all relay node."""
def astext(self, annotate=None): def astext(self, show_meta_data=True, annotate=None):
"""Get the text format of the expression. """Get the text format of the expression.
Returns Returns
...@@ -31,11 +31,21 @@ class RelayNode(NodeBase): ...@@ -31,11 +31,21 @@ class RelayNode(NodeBase):
text : str text : str
The text format of the expression. The text format of the expression.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str] annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional Optional annotate function to provide additional
information in the comment block. information in the comment block.
Note
----
meta data section is necessary to fully parse the text format.
However, it can contain dumps that are big (constant weights),
so it can be helpful to skip printing the meta data section.
""" """
return _expr.RelayPrint(self, annotate) return _expr.RelayPrint(self, show_meta_data, annotate)
@register_relay_node @register_relay_node
......
...@@ -113,6 +113,11 @@ class TextMetaDataContext { ...@@ -113,6 +113,11 @@ class TextMetaDataContext {
return SaveJSON(Array<NodeRef>(meta_data_)); return SaveJSON(Array<NodeRef>(meta_data_));
} }
/*!
 * \brief Query whether the meta data list has no entries.
 * \return true when meta_data_ is empty, false otherwise.
 */
bool empty() const {
  const bool has_no_entries = meta_data_.empty();
  return has_no_entries;
}
private: private:
/*! \brief additional metadata stored in TVM json format */ /*! \brief additional metadata stored in TVM json format */
std::vector<NodeRef> meta_data_; std::vector<NodeRef> meta_data_;
...@@ -125,8 +130,9 @@ class TextPrinter : ...@@ -125,8 +130,9 @@ class TextPrinter :
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*) public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*) public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public: public:
explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate) explicit TextPrinter(bool show_meta_data,
: annotate_(annotate) {} runtime::TypedPackedFunc<std::string(Expr)> annotate)
: show_meta_data_(show_meta_data), annotate_(annotate) {}
/*! /*!
* \brief Print a node to string. * \brief Print a node to string.
* \param node. * \param node.
...@@ -144,13 +150,17 @@ class TextPrinter : ...@@ -144,13 +150,17 @@ class TextPrinter :
} else { } else {
stream_ << node; stream_ << node;
} }
std::string meta_json = meta_.GetMetaSection(); if (!meta_.empty()) {
if (meta_json.length() != 0) { if (show_meta_data_) {
// append meta data in the end. std::string meta_json = meta_.GetMetaSection();
stream_ << "# meta data\n" // append meta data in the end.
<< "r\"\"\"\n" stream_ << "# meta data\n"
<< meta_json << "\n" << "r\"\"\"\n"
<< "\"\"\""; << meta_json << "\n"
<< "\"\"\"";
} else {
stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n";
}
} }
return stream_.str(); return stream_.str();
} }
...@@ -227,7 +237,9 @@ class TextPrinter : ...@@ -227,7 +237,9 @@ class TextPrinter :
TextValue id = this->AllocTempVar(); TextValue id = this->AllocTempVar();
this->PrintIndent(); this->PrintIndent();
stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op)); stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op));
this->PrintEndInst("\n"); this->PrintEndInst("");
this->PrintOptionalInfo(GetRef<Expr>(op));
stream_ << '\n';
return id; return id;
} }
...@@ -697,6 +709,8 @@ class TextPrinter : ...@@ -697,6 +709,8 @@ class TextPrinter :
private: private:
class AttrPrinter; class AttrPrinter;
friend class AttrPrinter; friend class AttrPrinter;
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */ /*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_; runtime::TypedPackedFunc<std::string(Expr)> annotate_;
/*! \brief meta data context */ /*! \brief meta data context */
...@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op, ...@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
} }
std::string RelayPrint(const NodeRef& node, std::string RelayPrint(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) { runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return TextPrinter(annotate).Print(node); return TextPrinter(show_meta_data, annotate).Print(node);
} }
TVM_REGISTER_API("relay._expr.RelayPrint") TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string( .set_body_typed<std::string(
const NodeRef&, const NodeRef&, bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint); runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
} // namespace relay } // namespace relay
......
...@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator { ...@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator {
} }
// Debug function, dump the group assignment in text. // Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) { void DebugDumpGroup(const Expr& body) {
std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string { std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
auto it = gmap_.find(expr.get()); auto it = gmap_.find(expr.get());
if (it == gmap_.end()) return ""; if (it == gmap_.end()) return "";
std::ostringstream os; std::ostringstream os;
......
...@@ -77,7 +77,9 @@ def test_add_op_broadcast(): ...@@ -77,7 +77,9 @@ def test_add_op_broadcast():
def test_with_params(): def test_with_params():
x = relay.var('x', shape=(10, 5)) x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5)) y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y)) z = relay.add(x, y)
z = relay.exp(z)
func = relay.Function([x, y], z)
x_data = np.random.rand(10, 5).astype('float32') x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32') y_data = np.random.rand(1, 5).astype('float32')
params = {"y": y_data} params = {"y": y_data}
...@@ -87,11 +89,40 @@ def test_with_params(): ...@@ -87,11 +89,40 @@ def test_with_params():
mod.set_input(x=x_data) mod.set_input(x=x_data)
mod.run() mod.run()
res = mod.get_output(0).asnumpy() res = mod.get_output(0).asnumpy()
ref_res = y_data + x_data ref_res = np.exp(y_data + x_data)
tvm.testing.assert_allclose(res, ref_res) tvm.testing.assert_allclose(res, ref_res)
def test_plan_memory():
    """Check the number of distinct storage ids produced by GraphPlanMemory.

    Builds ``exp`` applied five times on top of ``add(x, exp(y))``, runs
    type inference and op fusion (opt_level=0), then collects every
    storage id in the returned storage map and asserts there are four.
    """
    # it is sufficient to cycle through two memories.
    x = relay.var("x", shape=(10,))
    # Give the second input its own name hint (it was "x" before).
    y = relay.var("y", shape=(1,))
    y2 = relay.exp(y)
    z = relay.add(x, y2)
    # Chain of five unary ops on top of the add.
    for _ in range(5):
        z = relay.exp(z)
    func = relay.Function([x, y], z)
    func = relay.ir_pass.infer_type(func)
    func = relay.ir_pass.fuse_ops(func, opt_level=0)
    func = relay.ir_pass.infer_type(func)
    smap = relay.backend._backend.GraphPlanMemory(func)
    storage_ids = set()
    # Use names that do not shadow the relay vars defined above.
    for _, storage_list in smap.items():
        for sid in storage_list:
            storage_ids.add(sid.value)
    # The current rule requires each var to have a unique storage id.
    # Because no in-place update is done, the intermediates need
    # another two alternating temporary spaces: four ids in total.
    assert len(storage_ids) == 4
if __name__ == "__main__": if __name__ == "__main__":
test_plan_memory()
test_with_params() test_with_params()
test_add_op_scalar() test_add_op_scalar()
test_add_op_tensor() test_add_op_tensor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment