"""
A compiler from a Relay expression to TVM's graph runtime.

The compiler is built from a few pieces.

First we define a compiler from a single Relay expression to the
graph language. We require the expression to be a function.
The function's parameters correspond to the placeholder/inputs
and model parameters found in the computation graph representation.
The body of the function represents the computation graph.

The compiler's output is a program in the graph language, which is composed
of Node, NodeRef, InputNode and OpNode. This "little language" represents
programs in TVM's graph format.

To connect to the graph runtime, we use a printer that converts our graph
format into TVM's JSON format. The resulting string can be loaded by
contrib.graph_runtime or any other TVM-runtime-compatible system.
"""
from __future__ import absolute_import

import json
from collections import defaultdict, OrderedDict

import attr

from . import _backend
from . import compile_engine
from ..op import Op
from ..expr import Function, GlobalVar
from ..expr_functor import ExprFunctor
from ..ty import TupleType, TensorType
from ... import target as _target


@attr.s
class NodeRef(object):
    """A reference to a node, used for constructing the graph."""
    ident = attr.ib()
    index = attr.ib(default=0)
    version = attr.ib(default=0)

    def to_json(self):
        """Serialize as the ``[node_id, output_index, version]`` triple."""
        return [self.ident, self.index, self.version]


@attr.s
class Node(object):
    """The base class for nodes in the TVM runtime system graph input."""
    name = attr.ib()
    attrs = attr.ib()

    def to_json(self):
        """Serialize this node; must be implemented by subclasses."""
        raise NotImplementedError("Abstract method, please implement me.")


@attr.s
class InputNode(Node):
    """An input node in the TVM runtime system graph input."""
    name = attr.ib()
    attrs = attr.ib()

    def to_json(self):
        """Serialize as a ``null`` op with no inputs."""
        return {
            "op": "null",
            "name": self.name,
            "inputs": []
        }


@attr.s
class OpNode(Node):
    """An operator node in the TVM runtime system's graph input."""
    op_name = attr.ib()
    inputs = attr.ib()
    op_attrs = attr.ib()
    num_outputs = attr.ib(default=1)

    def to_json(self):
        """Serialize as a ``tvm_op`` node, extending op_attrs with call info."""
        # Copy so the node's own op_attrs are not modified by serialization.
        attrs = dict(self.op_attrs)
        # Extend ops with extra info.
        attrs["func_name"] = self.op_name
        attrs["flatten_data"] = "0"
        attrs["num_inputs"] = str(len(self.inputs))
        attrs["num_outputs"] = str(self.num_outputs)

        return {
            "op": "tvm_op",
            "name": self.name,
            "attrs": attrs,
            "inputs": self.inputs
        }


def shape_to_json(shape):
    """Convert a symbolic shape to a JSON-compatible list of ints."""
    return [sh.value for sh in shape]


class GraphRuntimeCodegen(ExprFunctor):
    """The compiler from Relay to the TVM runtime system."""

    def __init__(self, mod, target):
        ExprFunctor.__init__(self)
        self.mod = mod
        # Either a single target (str / tvm.target.Target) for homogeneous
        # execution, or a dict {device_type: target} for heterogeneous
        # execution. It is never modified after construction.
        self.target = target
        self.nodes = []
        self.var_map = {}
        self.params = {}
        # Set by `codegen` to the NodeRef(s) produced by the function body.
        self.heads = None
        self.storage_device_map = None
        self.compile_engine = compile_engine.get()
        # Maps target key -> set of lowered functions compiled for it.
        self.lowered_funcs = defaultdict(set)
        self._name_map = {}

    def add_node(self, node, expr):
        """
        Add a node to the graph.

        The node is annotated with the storage ids (and, when every entry is
        device-annotated, the device indices) from the memory plan, plus the
        shape and dtype of each of its outputs.

        Parameters
        ----------
        node: Node
            The node to add to the graph.

        expr: tvm.relay.Expr
            The corresponding expression.

        Returns
        -------
        node_ref: Union[NodeRef, Tuple[NodeRef, ...]]
            A reference to the node, or one reference per output when the
            expression has a tuple type.

        Raises
        ------
        RuntimeError
            If the graph is only partially device-annotated, or the
            expression's type is neither a tensor nor a tuple of tensors.
        """
        checked_type = expr.checked_type
        # setup storage ids
        assert expr in self.storage_device_map
        storage_device_info = self.storage_device_map[expr]
        assert len(storage_device_info) == 2
        node.attrs["storage_id"] = [x.value for x in storage_device_info[0]]
        device_types = [x.value for x in storage_device_info[1]]
        # Device type 0 marks an entry without a device annotation.
        num_unknown_devices = device_types.count(0)
        if num_unknown_devices != 0 and num_unknown_devices != len(device_types):
            raise RuntimeError("The graph contains not annotated nodes for "
                               "heterogeneous execution. All nodes must be "
                               "annotated.")

        # Add the `device_index` attribute when the graph is annotated.
        if num_unknown_devices == 0:
            node.attrs["device_index"] = device_types

        node_id = len(self.nodes)
        self.nodes.append(node)
        # Tuple return value, flatten as tuple
        if isinstance(checked_type, TupleType):
            ret = []
            shape = []
            dtype = []
            for i, typ in enumerate(checked_type.fields):
                if not isinstance(typ, TensorType):
                    raise RuntimeError("type %s not supported" % typ)
                ret.append(NodeRef(node_id, i))
                shape.append(shape_to_json(typ.shape))
                dtype.append(typ.dtype)
            node.attrs["shape"] = shape
            node.attrs["dtype"] = dtype
            assert isinstance(node, OpNode)
            node.num_outputs = len(checked_type.fields)
            return tuple(ret)
        # Normal tensor return type
        if not isinstance(checked_type, TensorType):
            raise RuntimeError("type %s not supported" % checked_type)
        node.attrs["shape"] = [shape_to_json(checked_type.shape)]
        node.attrs["dtype"] = [checked_type.dtype]
        node.num_outputs = 1
        return NodeRef(node_id, 0)

    def visit_tuple(self, vtuple):
        """Return a tuple of NodeRefs, one per (non-tuple) field."""
        fields = []
        for field in vtuple.fields:
            ref = self.visit(field)
            assert isinstance(ref, NodeRef)
            fields.append(ref)
        return tuple(fields)

    def visit_tuple_getitem(self, op):
        """Select one NodeRef out of a visited tuple."""
        vtuple = self.visit(op.tuple_value)
        assert isinstance(vtuple, tuple)
        return vtuple[op.index]

    def visit_constant(self, op):
        """Register the constant as a parameter and emit an input node."""
        index = len(self.params)
        name = "p%d" % index
        self.params[name] = op.data
        node = InputNode(name, {})
        return self.add_node(node, op)

    def visit_function(self, _):
        raise RuntimeError("function not supported")

    def visit_if(self, _):
        raise RuntimeError("if not supported")

    def visit_global_var(self, _):
        raise RuntimeError("global var not supported")

    def visit_let(self, let):
        """
        Visit the let binding, by first traversing its value,
        then setting the metadata on the returned NodeRef.

        Finally visit the body, and return the NodeRef corresponding
        to it.

        Parameters
        ----------
        let: tvm.relay.Expr
            The let binding to transform.

        Returns
        -------
        ref: NodeRef
            The node reference to the body.
        """
        assert let.var not in self.var_map
        self.var_map[let.var] = self.visit(let.value)
        return self.visit(let.body)

    def visit_var(self, rvar):
        return self.var_map[rvar]

    @staticmethod
    def _is_primitive(func):
        """Return True when `func` carries the attribute ``Primitive == 1``."""
        if func.attrs is None:
            return False
        return int(getattr(func.attrs, "Primitive", 0)) == 1

    def _resolve_target(self, call_dev_type):
        """Pick the target for a call on device `call_dev_type`.

        Returns
        -------
        (target, key)
            `target` is passed to the compile engine; `key` indexes
            ``self.lowered_funcs``.

        Raises
        ------
        Exception
            If no target is configured for `call_dev_type` (heterogeneous).
        ValueError
            If ``self.target`` has an unsupported type.
        """
        if isinstance(self.target, (str, _target.Target)):
            # homogeneous execution: a single target for every call.
            return self.target, str(self.target)
        if isinstance(self.target, dict):
            # heterogeneous execution.
            if call_dev_type not in self.target:
                raise Exception("No target is provided for device " +
                                "{0}".format(call_dev_type))
            return self.target[call_dev_type], self.target[call_dev_type]
        raise ValueError("self.target must be the type of str, " +
                         "tvm.target.Target, or dict of int to str")

    def visit_call(self, call):
        """Transform a ::tvm.relay.Call into an operator in the TVM graph."""
        if isinstance(call.op, Op):
            raise Exception(
                "Operators should be transformed away; try applying " +
                "the fuse_ops transformation to the expression.")
        elif isinstance(call.op, GlobalVar):
            func = self.mod[call.op]
        elif isinstance(call.op, Function):
            func = call.op
        else:
            raise Exception(
                "TVM runtime does not support calls to {0}".format(type(call.op)))
        if not self._is_primitive(func):
            raise Exception(
                "TVM only support calls to primitive functions " +
                "(i.e functions composed of fusable operator invocations)")

        assert call in self.storage_device_map
        device_types = self.storage_device_map[call][1]
        call_dev_type = device_types[0].value
        target, target_key = self._resolve_target(call_dev_type)
        cached_func = self.compile_engine.lower(func, target)
        for loweredf in cached_func.funcs:
            self.lowered_funcs[target_key].add(loweredf)

        inputs = []
        # flatten tuple in the call.
        for arg in call.args:
            res = self.visit(arg)
            if isinstance(arg.checked_type, TupleType):
                assert isinstance(res, tuple)
                inputs += res
            else:
                inputs.append(res)

        inputs = [x.to_json() for x in inputs]
        op_name = cached_func.func_name
        op_node = OpNode(self._get_unique_name(op_name), {},
                         op_name, inputs, {})
        return self.add_node(op_node, call)

    def visit_op(self, _):
        raise Exception("can not compile op in non-eta expanded form")

    def visit_ref_create(self, _):
        raise RuntimeError("reference not supported")

    def visit_ref_read(self, _):
        raise RuntimeError("reference not supported")

    def visit_ref_write(self, _):
        raise RuntimeError("reference not supported")

    def visit_constructor(self, _):
        raise Exception("ADT constructor case not yet implemented")

    def visit_match(self, _):
        raise Exception("match case not yet implemented")

    def _get_json(self):
        """
        Convert the sequence of nodes stored by the compiler into the
        TVM graph runtime format.

        Returns
        -------
        graph_json : str
            The generated JSON as a string.
        """
        # First we compute "nodes" field.
        nodes = [node.to_json() for node in self.nodes]

        # Compute "arg_nodes" and "heads" fields.
        arg_nodes = [i for i, node in enumerate(self.nodes)
                     if isinstance(node, InputNode)]
        heads = self.heads if isinstance(self.heads, tuple) else [self.heads]
        heads = [x.to_json() for x in heads]

        # Compute "node_row_ptr" and entry attributes. Entries are the
        # flattened outputs of all nodes: node i owns entries
        # node_row_ptr[i] .. node_row_ptr[i + 1] - 1.
        num_entry = 0
        shapes = []
        storage_ids = []
        device_types = []
        dltypes = []
        node_row_ptr = [0]
        for node in self.nodes:
            assert node.num_outputs == len(node.attrs["shape"])
            shapes += node.attrs["shape"]
            dltypes += node.attrs["dtype"]
            storage_ids += node.attrs["storage_id"]
            if "device_index" in node.attrs:
                device_types += node.attrs["device_index"]
            num_entry += node.num_outputs
            node_row_ptr.append(num_entry)

        # Compute "attrs" field.
        attrs = {}
        attrs["shape"] = ["list_shape", shapes]
        attrs["storage_id"] = ["list_int", storage_ids]
        if device_types:
            attrs["device_index"] = ["list_int", device_types]
        attrs["dltype"] = ["list_str", dltypes]

        # Metadata definitions
        def nested_defaultdict():
            return defaultdict(nested_defaultdict)

        metadata = nested_defaultdict()

        def record(kind, node_name, node_id, entry_id):
            # shapes/dltypes are per-entry, so they are indexed by entry id,
            # which differs from node id once any node has several outputs.
            info = metadata['signatures']['default'][kind][node_name]
            info['id'] = node_id
            info['dtype'] = dltypes[entry_id]
            info['shape'] = shapes[entry_id]

        for node_id in arg_nodes:
            node_name = nodes[node_id]['name']
            # Constants (registered in self.params) are not user inputs.
            if node_name not in self.params:
                record('inputs', node_name, node_id, node_row_ptr[node_id])
        for head in heads:
            node_id, index = head[0], head[1]
            # NOTE(review): outputs of one multi-output node share a node
            # name, so later outputs overwrite earlier ones in this map.
            record('outputs', nodes[node_id]['name'], node_id,
                   node_row_ptr[node_id] + index)

        # Keep 'metadata' always at end
        json_dict = OrderedDict([
            ("nodes", nodes),
            ("arg_nodes", arg_nodes),
            ("heads", heads),
            ("attrs", attrs),
            ("node_row_ptr", node_row_ptr),
            ("metadata", metadata),
        ])

        return json.dumps(json_dict, indent=2)

    def _dump_storage_device_info(self, func, field):
        """Render `func` as text, annotating each planned expression with
        element `field` of its storage_device_map entry."""
        def _annotate(expr):
            if expr in self.storage_device_map:
                storage_device_info = self.storage_device_map[expr]
                assert len(storage_device_info) == 2
                return str(storage_device_info[field])
            return ""
        return func.astext(show_meta_data=False, annotate=_annotate)

    def debug_dump_memory_plan(self, func):
        """Debug function to dump memory plan."""
        return self._dump_storage_device_info(func, 0)

    def debug_dump_device_annotation(self, func):
        """Debug function to dump device annotation result."""
        return self._dump_storage_device_info(func, 1)

    def codegen(self, func):
        """Compile a single function into a graph.

        Parameters
        ----------
        func: tvm.relay.Expr
            The function to compile.

        Returns
        -------
        graph_json : str
            The graph json that can be consumed by runtime.

        lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
            The lowered functions: a list for homogeneous compilation (empty
            when the function contains no primitive calls), otherwise a dict
            keyed by target.

        params : Dict[str, tvm.nd.NDArray]
            Additional constant parameters.
        """
        self.storage_device_map = _backend.GraphPlanMemory(func)
        # First we convert all the parameters into input nodes.
        for param in func.params:
            node = InputNode(param.name_hint, {})
            self.var_map[param] = self.add_node(node, param)

        # Then we compile the body into a graph which can depend
        # on input variables.
        self.heads = self.visit(func.body)
        graph_json = self._get_json()

        # Return the lowered functions as a list for homogeneous compilation.
        # Otherwise, for heterogeneous compilation, a dictionary mapping each
        # target to a list of lowered functions is returned. Both forms are
        # acceptable to tvm.build.
        if isinstance(self.target, dict):
            lowered_funcs = {k: list(v) for k, v in self.lowered_funcs.items()}
        else:
            lowered_funcs = [f for funcs in self.lowered_funcs.values()
                             for f in funcs]
        return graph_json, lowered_funcs, self.params

    def _get_unique_name(self, name):
        """Return `name`, or `name` suffixed with a counter if already used."""
        if name not in self._name_map:
            self._name_map[name] = 1
            return name
        index = self._name_map[name]
        self._name_map[name] += 1
        return self._get_unique_name(name + str(index))