Commit b9328d02 by Hua Jiang Committed by Thierry Moreau

[VTA] Support network which have no unique operator as start/stop name for graph pack. (#4703)

* [VTA] Support network which have no unique operator as start/stop name
for graph pack.

[Issue]
  Current vta use 'start' and 'stop' name to define the pack start point
  and end point, but this method not work for these network which have
  no 2 unique operator as  start point and stop point.

[Solution]
  In this solution we give 2 addtional parameters start_name_indx and
  stop_name_indx to make vta pack logic work with the said network,
  for exampl for following networks which have no unique operator,

  %0 = nn.add
  %1 = nn.conv2d
  %2 = nn.batch_norm
  %3 = nn.leaky_relu
  %4 = nn.add
  %5 = nn.conv2d
  %6 = nn.batch_norm
  %7 = nn.leaky_relu
  %8 = nn.add

  with this solution we can use following parameter format to make
  vta work on it.

  relay_prog = graph_pack(
                //....
                start_name="nn.add",
                stop_name="nn.add",
                start_name_idx=0,
                stop_name_idx=4)

  to apply on new network, by printing the network we can get index information like following.

  print(mod.astext(show_meta_data=False))
  relay_prog = graph_pack(mod
                          ...
                          start_name="nn.add",
                          stop_name="nn.add",
                          start_name_idx=0,
                          stop_name_idx=4)

* address review comments and fix index count bug

issue:
when do print(mod), the output not only the Call is also have other type
like Var, need add logic to count all except meta.

solution:
add related logic

* address review comments.

* address review comments

* add more detail comments.
parent 23ba37d4
......@@ -110,6 +110,15 @@ def _get_shape(node):
"""
return _to_shape(node.checked_type.shape)
def _operator_idx_inc(expr, count_meta, operator_current_idx):
"""Increase operator index
"""
if isinstance(expr, relay.expr.Constant):
operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx
else:
operator_current_idx = operator_current_idx + 1
return operator_current_idx
class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST.
"""
......@@ -246,7 +255,7 @@ class ExprPack(ExprMutator):
class BT(Exception):
pass
def get_subgraph(expr, start_name, stop_name):
def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
""" We assume stop_name only appears once for simplicity.
This constraint will be lifted in the future.
bitpack_start and bitpack_end are both inclusive.
......@@ -254,24 +263,32 @@ def get_subgraph(expr, start_name, stop_name):
bitpack_start = op.op.get('annotation.bitpack_start')
bitpack_end = op.op.get('annotation.bitpack_end')
anf = run_opt_pass(expr, transform.ToANormalForm())
def _recursion(anf, start_found, stop_found):
operator_current_idx = 0
def _recursion(anf, start_found, stop_found, operator_current_idx):
""" Helper to obtain the subgraph.
"""
if isinstance(anf, relay.expr.Function):
return relay.expr.Function(anf.params,
_recursion(anf.body, start_found, stop_found),
_recursion(anf.body, start_found, stop_found,
operator_current_idx),
anf.ret_type, anf.type_params, anf.attrs)
elif isinstance(anf, relay.expr.Let):
value = anf.value
if isinstance(value, relay.expr.Call):
if isinstance(value.op, relay.op.Op):
if value.op.name == start_name and not start_found:
value = relay.expr.Call(bitpack_start, [value])
start_found = True
if operator_current_idx == start_name_idx or start_name_idx is None:
value = relay.expr.Call(bitpack_start, [value])
start_found = True
elif value.op.name == stop_name:
raise BT()
if operator_current_idx == stop_name_idx or stop_name_idx is None:
raise BT()
operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)
try:
return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found))
return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found,
operator_current_idx))
except BT:
assert start_found
assert not stop_found
......@@ -283,7 +300,7 @@ def get_subgraph(expr, start_name, stop_name):
assert start_found
assert stop_found
return anf
annotated = _recursion(anf, False, False)
annotated = _recursion(anf, False, False, operator_current_idx)
return run_opt_pass(annotated, transform.ToGraphNormalForm())
def graph_pack(expr,
......@@ -291,7 +308,10 @@ def graph_pack(expr,
cfactor,
weight_bits,
start_name="nn.max_pool2d",
stop_name="nn.global_avg_pool2d"):
stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False):
"""Pack the graph into batch&channel packed format.
Parameters
......@@ -309,10 +329,24 @@ def graph_pack(expr,
The bit-width of the weights.
start_name: str, optional
Start packing from certain known node.
Start packing from certain known node when start_name_idx is None.
stop_name: str, optional
Stop packing from certain known node.
Stop packing from certain known node when stop_name_idx is None.
start_name_idx: int, optional
When start_name_idx not None, start packing only when node name equal start_name
and node idx equals start_name_idx.
stop_name_idx: int, optional
When stop_name_idx not None, stop packing only when node name equal stop_name
and node index equals stop_name_idx.
count_meta:boolean, optional
When count_meta is False, the operator increase logic would not count the meta that have
the type 'relay.expr.Constant', start_name_idx and stop_name_idx follow the index from
'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
logic would count the meta.
Returns
-------
......@@ -320,7 +354,8 @@ def graph_pack(expr,
The transformed expression.
"""
assert isinstance(expr, relay.Function)
expr = get_subgraph(expr, start_name, stop_name)
assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(
bfactor, cfactor,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment