Commit b9328d02 by Hua Jiang Committed by Thierry Moreau

[VTA] Support network which have no unique operator as start/stop name for graph pack. (#4703)

* [VTA] Support network which have no unique operator as start/stop name
for graph pack.

[Issue]
  Current vta use 'start' and 'stop' name to define the pack start point
  and end point, but this method not work for these network which have
  no 2 unique operator as  start point and stop point.

[Solution]
  In this solution we give 2 addtional parameters start_name_indx and
  stop_name_indx to make vta pack logic work with the said network,
  for exampl for following networks which have no unique operator,

  %0 = nn.add
  %1 = nn.conv2d
  %2 = nn.batch_norm
  %3 = nn.leaky_relu
  %4 = nn.add
  %5 = nn.conv2d
  %6 = nn.batch_norm
  %7 = nn.leaky_relu
  %8 = nn.add

  with this solution we can use following parameter format to make
  vta work on it.

  relay_prog = graph_pack(
                //....
                start_name="nn.add",
                stop_name="nn.add",
                start_name_idx=0,
                stop_name_idx=4)

  to apply on new network, by printing the network we can get index information like following.

  print(mod.astext(show_meta_data=False))
  relay_prog = graph_pack(mod
                          ...
                          start_name="nn.add",
                          stop_name="nn.add",
                          start_name_idx=0,
                          stop_name_idx=4)

* address review comments and fix index count bug

issue:
when do print(mod), the output not only the Call is also have other type
like Var, need add logic to count all except meta.

solution:
add related logic

* address review comments.

* address review comments

* add more detail comments.
parent 23ba37d4
...@@ -110,6 +110,15 @@ def _get_shape(node): ...@@ -110,6 +110,15 @@ def _get_shape(node):
""" """
return _to_shape(node.checked_type.shape) return _to_shape(node.checked_type.shape)
def _operator_idx_inc(expr, count_meta, operator_current_idx):
"""Increase operator index
"""
if isinstance(expr, relay.expr.Constant):
operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx
else:
operator_current_idx = operator_current_idx + 1
return operator_current_idx
class ExprPack(ExprMutator): class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST. """Visitor to perform graph packing on an AST.
""" """
...@@ -246,7 +255,7 @@ class ExprPack(ExprMutator): ...@@ -246,7 +255,7 @@ class ExprPack(ExprMutator):
class BT(Exception): class BT(Exception):
pass pass
def get_subgraph(expr, start_name, stop_name): def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
""" We assume stop_name only appears once for simplicity. """ We assume stop_name only appears once for simplicity.
This constraint will be lifted in the future. This constraint will be lifted in the future.
bitpack_start and bitpack_end are both inclusive. bitpack_start and bitpack_end are both inclusive.
...@@ -254,24 +263,32 @@ def get_subgraph(expr, start_name, stop_name): ...@@ -254,24 +263,32 @@ def get_subgraph(expr, start_name, stop_name):
bitpack_start = op.op.get('annotation.bitpack_start') bitpack_start = op.op.get('annotation.bitpack_start')
bitpack_end = op.op.get('annotation.bitpack_end') bitpack_end = op.op.get('annotation.bitpack_end')
anf = run_opt_pass(expr, transform.ToANormalForm()) anf = run_opt_pass(expr, transform.ToANormalForm())
def _recursion(anf, start_found, stop_found): operator_current_idx = 0
def _recursion(anf, start_found, stop_found, operator_current_idx):
""" Helper to obtain the subgraph. """ Helper to obtain the subgraph.
""" """
if isinstance(anf, relay.expr.Function): if isinstance(anf, relay.expr.Function):
return relay.expr.Function(anf.params, return relay.expr.Function(anf.params,
_recursion(anf.body, start_found, stop_found), _recursion(anf.body, start_found, stop_found,
operator_current_idx),
anf.ret_type, anf.type_params, anf.attrs) anf.ret_type, anf.type_params, anf.attrs)
elif isinstance(anf, relay.expr.Let): elif isinstance(anf, relay.expr.Let):
value = anf.value value = anf.value
if isinstance(value, relay.expr.Call): if isinstance(value, relay.expr.Call):
if isinstance(value.op, relay.op.Op): if isinstance(value.op, relay.op.Op):
if value.op.name == start_name and not start_found: if value.op.name == start_name and not start_found:
value = relay.expr.Call(bitpack_start, [value]) if operator_current_idx == start_name_idx or start_name_idx is None:
start_found = True value = relay.expr.Call(bitpack_start, [value])
start_found = True
elif value.op.name == stop_name: elif value.op.name == stop_name:
raise BT() if operator_current_idx == stop_name_idx or stop_name_idx is None:
raise BT()
operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)
try: try:
return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found)) return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found,
operator_current_idx))
except BT: except BT:
assert start_found assert start_found
assert not stop_found assert not stop_found
...@@ -283,7 +300,7 @@ def get_subgraph(expr, start_name, stop_name): ...@@ -283,7 +300,7 @@ def get_subgraph(expr, start_name, stop_name):
assert start_found assert start_found
assert stop_found assert stop_found
return anf return anf
annotated = _recursion(anf, False, False) annotated = _recursion(anf, False, False, operator_current_idx)
return run_opt_pass(annotated, transform.ToGraphNormalForm()) return run_opt_pass(annotated, transform.ToGraphNormalForm())
def graph_pack(expr, def graph_pack(expr,
...@@ -291,7 +308,10 @@ def graph_pack(expr, ...@@ -291,7 +308,10 @@ def graph_pack(expr,
cfactor, cfactor,
weight_bits, weight_bits,
start_name="nn.max_pool2d", start_name="nn.max_pool2d",
stop_name="nn.global_avg_pool2d"): stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False):
"""Pack the graph into batch&channel packed format. """Pack the graph into batch&channel packed format.
Parameters Parameters
...@@ -309,10 +329,24 @@ def graph_pack(expr, ...@@ -309,10 +329,24 @@ def graph_pack(expr,
The bit-width of the weights. The bit-width of the weights.
start_name: str, optional start_name: str, optional
Start packing from certain known node. Start packing from certain known node when start_name_idx is None.
stop_name: str, optional stop_name: str, optional
Stop packing from certain known node. Stop packing from certain known node when stop_name_idx is None.
start_name_idx: int, optional
When start_name_idx not None, start packing only when node name equal start_name
and node idx equals start_name_idx.
stop_name_idx: int, optional
When stop_name_idx not None, stop packing only when node name equal stop_name
and node index equals stop_name_idx.
count_meta:boolean, optional
When count_meta is False, the operator increase logic would not count the meta that have
the type 'relay.expr.Constant', start_name_idx and stop_name_idx follow the index from
'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
logic would count the meta.
Returns Returns
------- -------
...@@ -320,7 +354,8 @@ def graph_pack(expr, ...@@ -320,7 +354,8 @@ def graph_pack(expr,
The transformed expression. The transformed expression.
""" """
assert isinstance(expr, relay.Function) assert isinstance(expr, relay.Function)
expr = get_subgraph(expr, start_name, stop_name) assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType()) expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack( packer = ExprPack(
bfactor, cfactor, bfactor, cfactor,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment