Unverified Commit c113712d by Tianqi Chen Committed by GitHub

[RELAY][BACKEND] Enable PlanMemory in the graph runtime. (#2120)

parent 6edb3564
......@@ -458,12 +458,14 @@ inline const TTypeNode* ExprNode::type_as() const {
/*!
* \brief Print node as text format.
* \param node The node to be printed.
* \param show_meta_data Whether to print meta data section.
* \param annotate An optional callback function for attaching
* additional comment block to an expr.
* \return The text representation.
*/
std::string RelayPrint(
const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
} // namespace relay
} // namespace tvm
......
......@@ -55,6 +55,7 @@ def build(funcs, target, target_host=None):
funcs : List[tvm.LoweredFunc]
The list of lowered functions.
target : tvm.Target
The target to run the code on.
......
......@@ -21,6 +21,7 @@ contrib.graph_runtime or any other TVM runtime compatible system.
from __future__ import absolute_import
import json
import attr
from . import _backend
from . import compile_engine
from ..op import Op
from ..expr import Function, GlobalVar, ExprFunctor
......@@ -103,11 +104,12 @@ class GraphRuntimeCodegen(ExprFunctor):
self.nodes = []
self.var_map = {}
self.params = {}
self.storage_map = None
self.compile_engine = compile_engine.get()
self.lowered_funcs = set()
self._name_map = {}
def add_node(self, node, checked_type):
def add_node(self, node, expr):
"""
Add a node to the graph.
......@@ -116,14 +118,21 @@ class GraphRuntimeCodegen(ExprFunctor):
node: Node
The node to add to the graph.
checked_type: Type
The type of the node.
expr: tvm.relay.Expr
The corresponding expression.
Returns
-------
node_ref: Union[NodeRef, List[NodeRef]]
A reference to the node.
"""
checked_type = expr.checked_type
# setup storage ids
assert expr in self.storage_map
node.attrs["storage_id"] = [
x.value for x in self.storage_map[expr]
]
node_id = len(self.nodes)
self.nodes.append(node)
# Tuple return value, flatten as tuple
......@@ -168,7 +177,7 @@ class GraphRuntimeCodegen(ExprFunctor):
name = "p%d" % index
self.params[name] = op.data
node = InputNode(name, {})
return self.add_node(node, op.checked_type)
return self.add_node(node, op)
def visit_function(self, _):
    """Reject function-valued expressions.

    Raises
    ------
    RuntimeError
        Always raised; function expressions are not supported
        as graph nodes by this codegen.
    """
    raise RuntimeError("function not supported")
......@@ -244,7 +253,7 @@ class GraphRuntimeCodegen(ExprFunctor):
op_name = cached_func.func_name
op_node = OpNode(self._get_unique_name(op_name), {},
op_name, inputs, {})
return self.add_node(op_node, call.checked_type)
return self.add_node(op_node, call)
def _get_json(self):
"""
......@@ -281,8 +290,7 @@ class GraphRuntimeCodegen(ExprFunctor):
assert node.num_outputs == len(node.attrs["shape"])
shapes += node.attrs["shape"]
dltypes += node.attrs["dtype"]
for i in range(node.num_outputs):
storage_ids.append(i + num_entry)
storage_ids += node.attrs["storage_id"]
num_entry += node.num_outputs
node_row_ptr.append(num_entry)
......@@ -302,6 +310,14 @@ class GraphRuntimeCodegen(ExprFunctor):
return json.dumps(json_dict, indent=2)
def debug_dump_memory_plan(self, func):
    """Return the text form of ``func`` with each planned expression
    annotated by its storage ids (for debugging the memory plan)."""
    storage_map = self.storage_map

    def _annotate(expr):
        # Expressions absent from the plan get no annotation.
        return str(storage_map[expr]) if expr in storage_map else ""

    return func.astext(show_meta_data=False, annotate=_annotate)
def codegen(self, func):
"""Compile a single function into a graph.
......@@ -321,11 +337,12 @@ class GraphRuntimeCodegen(ExprFunctor):
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self.storage_map = _backend.GraphPlanMemory(func)
# First we convert all the parameters into input nodes.
for param in func.params:
node = InputNode(param.name_hint, {})
self.var_map[param] = self.add_node(
node, param.type_annotation)
node, param)
# Then we compile the body into a graph which can depend
# on input variables.
......
......@@ -23,7 +23,7 @@ def register_relay_node(type_key=None):
class RelayNode(NodeBase):
"""Base class of all relay node."""
def astext(self, show_meta_data=True, annotate=None):
    """Get the text format of the expression.

    Parameters
    ----------
    show_meta_data : bool
        Whether to include meta data section in the text
        if there is meta data.

    annotate: Optional[relay.Expr->str]
        Optional annotate function to provide additional
        information in the comment block.

    Returns
    -------
    text : str
        The text format of the expression.

    Note
    ----
    The meta data section is necessary to fully parse the text format.
    However, it can contain dumps that are big (constant weights),
    so it can be helpful to skip printing the meta data section.
    """
    return _expr.RelayPrint(self, show_meta_data, annotate)
@register_relay_node
......
......@@ -113,6 +113,11 @@ class TextMetaDataContext {
return SaveJSON(Array<NodeRef>(meta_data_));
}
/*!
 * \brief Whether no meta data has been recorded so far.
 *        Lets the printer skip the meta data section entirely.
 * \return whether the meta data context is empty.
 */
bool empty() const {
  return meta_data_.empty();
}
private:
/*! \brief additional metadata stored in TVM json format */
std::vector<NodeRef> meta_data_;
......@@ -125,8 +130,9 @@ class TextPrinter :
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public:
/*!
 * \brief Construct the text printer.
 * \param show_meta_data Whether to print the meta data section at the end.
 * \param annotate Optional callback that returns an additional comment
 *        string for a given expression (empty string for none).
 */
explicit TextPrinter(bool show_meta_data,
                     runtime::TypedPackedFunc<std::string(Expr)> annotate)
    : show_meta_data_(show_meta_data), annotate_(annotate) {}
/*!
* \brief Print a node to string.
* \param node.
......@@ -144,13 +150,17 @@ class TextPrinter :
} else {
stream_ << node;
}
std::string meta_json = meta_.GetMetaSection();
if (meta_json.length() != 0) {
// append meta data in the end.
stream_ << "# meta data\n"
<< "r\"\"\"\n"
<< meta_json << "\n"
<< "\"\"\"";
if (!meta_.empty()) {
if (show_meta_data_) {
std::string meta_json = meta_.GetMetaSection();
// append meta data in the end.
stream_ << "# meta data\n"
<< "r\"\"\"\n"
<< meta_json << "\n"
<< "\"\"\"";
} else {
stream_ << "# meta data omitted. you can use show_meta_data=True to include meta-data\n";
}
}
return stream_.str();
}
......@@ -227,7 +237,9 @@ class TextPrinter :
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << meta_.GetMetaNode(GetRef<NodeRef>(op));
this->PrintEndInst("\n");
this->PrintEndInst("");
this->PrintOptionalInfo(GetRef<Expr>(op));
stream_ << '\n';
return id;
}
......@@ -697,6 +709,8 @@ class TextPrinter :
private:
class AttrPrinter;
friend class AttrPrinter;
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */
runtime::TypedPackedFunc<std::string(Expr)> annotate_;
/*! \brief meta data context */
......@@ -790,13 +804,14 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
}
/*!
 * \brief Render a node in the Relay text format.
 * \param node The node to print.
 * \param show_meta_data Whether to append the meta data section.
 * \param annotate Optional callback providing an extra comment per expression.
 * \return The text representation.
 */
std::string RelayPrint(const NodeRef& node,
                       bool show_meta_data,
                       runtime::TypedPackedFunc<std::string(Expr)> annotate) {
  return TextPrinter(show_meta_data, annotate).Print(node);
}
TVM_REGISTER_API("relay._expr.RelayPrint")
.set_body_typed<std::string(
const NodeRef&,
const NodeRef&, bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
} // namespace relay
......
......@@ -749,7 +749,7 @@ class FuseMutator : private ExprMutator {
}
// Debug function, dump the group assignment in text.
void DebugDumpGroup(const Expr& body) {
std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string {
std::string text = RelayPrint(body, false, [this](const Expr& expr) -> std::string {
auto it = gmap_.find(expr.get());
if (it == gmap_.end()) return "";
std::ostringstream os;
......
......@@ -77,7 +77,9 @@ def test_add_op_broadcast():
def test_with_params():
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y))
z = relay.add(x, y)
z = relay.exp(z)
func = relay.Function([x, y], z)
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
params = {"y": y_data}
......@@ -87,11 +89,40 @@ def test_with_params():
mod.set_input(x=x_data)
mod.run()
res = mod.get_output(0).asnumpy()
ref_res = y_data + x_data
ref_res = np.exp(y_data + x_data)
tvm.testing.assert_allclose(res, ref_res)
def test_plan_memory():
    """GraphPlanMemory should cycle through two temporaries on an op chain.

    Vars get unique storage ids; since there is no in-place update, a
    chain of elementwise ops needs two alternating temporary buffers.
    """
    # it is sufficient to cycle through two memories.
    x = relay.var("x", shape=(10,))
    y = relay.var("y", shape=(1,))
    y2 = relay.exp(y)
    z = relay.add(x, y2)
    # Chain of exps: each step can reuse the buffer freed two steps ago.
    for _ in range(5):
        z = relay.exp(z)
    func = relay.Function([x, y], z)
    func = relay.ir_pass.infer_type(func)
    func = relay.ir_pass.fuse_ops(func, opt_level=0)
    func = relay.ir_pass.infer_type(func)
    smap = relay.backend._backend.GraphPlanMemory(func)
    storage_ids = set()
    for _, sids in smap.items():
        for sid in sids:
            storage_ids.add(sid.value)
    # Current rule requires vars to have unique storage ids; because we
    # don't do inplace, we need another two alternating temporary spaces:
    # 2 vars + 2 temporaries = 4 distinct ids.
    assert len(storage_ids) == 4
if __name__ == "__main__":
test_plan_memory()
test_with_params()
test_add_op_scalar()
test_add_op_tensor()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment