Unverified Commit d2799915 by Animesh Jain Committed by GitHub

[AutoTVM] Minor bug fixes in AutoTVM for QNN graphs (#4797)

* [AutoTVM] Minor bug fixes in AutoTVM for QNN graphs.

* Bring back strided_slice.

* Replace tvm.nd change.
parent 3fb937fe
...@@ -126,10 +126,10 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list): ...@@ -126,10 +126,10 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
for i, input_idx in enumerate(node_entry["inputs"]): for i, input_idx in enumerate(node_entry["inputs"]):
input_node_entry = node_list[input_idx[0]] input_node_entry = node_list[input_idx[0]]
input_type = input_node_entry["types"][input_idx[1]] input_type = input_node_entry["types"][input_idx[1]]
if not isinstance(input_node_entry["node"], (Var, Call)): if not isinstance(input_node_entry["node"], (Var, Constant, Call)):
raise RuntimeError("Graph tuner can only tune target " raise RuntimeError("Graph tuner can only tune target "
"operators with input node of type " "operators with input node of type "
"relay.expr.Var or relay.expr.Call. Now " "relay.expr.Var/Constant/Call. Now "
"find a target op %s with input type %s" "find a target op %s with input type %s"
% (op_name, str(type(input_node_entry["node"])))) % (op_name, str(type(input_node_entry["node"]))))
free_var = relay.Var("var_%d" % i, input_type) free_var = relay.Var("var_%d" % i, input_type)
...@@ -167,7 +167,8 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list): ...@@ -167,7 +167,8 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list):
else: else:
node_entry["inputs"].append([in_node_idx, 0, 0]) node_entry["inputs"].append([in_node_idx, 0, 0])
elif isinstance(node, Constant): elif isinstance(node, Constant):
pass node_entry["name"] = "Constant_" + str(node_index)
node_entry["types"] = [node.checked_type]
elif isinstance(node, relay.op.op.Op): elif isinstance(node, relay.op.op.Op):
return return
else: else:
......
...@@ -50,6 +50,7 @@ def _lower(mod, ...@@ -50,6 +50,7 @@ def _lower(mod,
grc.codegen(mod["main"]) grc.codegen(mod["main"])
# default case # default case
compiler = relay.vm.VMCompiler() compiler = relay.vm.VMCompiler()
if params:
compiler.set_params(params) compiler.set_params(params)
compiler.lower(mod, target=target) compiler.lower(mod, target=target)
...@@ -123,7 +124,9 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None, ...@@ -123,7 +124,9 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None,
# relay op -> topi compute # relay op -> topi compute
OP2TOPI = { OP2TOPI = {
tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw, tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc], topi.nn.group_conv2d_nchw,
topi.nn.conv2d_NCHWc,
topi.nn.conv2d_NCHWc_int8],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw], tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense], tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul], tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment