Unverified commit dbd805c1 by Haichen Shen, committed by GitHub

[AutoTVM] Temporary fix to the stack overflow issue in autotvm task extraction (#5019)

* Temporary fix to the stack overflow issue in autotvm task extraction

* fix lint

* fix graph tuner test
parent b91dbca6
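The change below touches the lowering path used by autotvm.task.extract_from_program, which compiles a Relay module only to trigger AutoTVM task registration. A minimal usage sketch of that entry point (not part of this commit; the model, target, and op choices are illustrative assumptions):

# Illustrative sketch only; resnet-18, llvm, and nn.conv2d are assumptions.
import tvm
from tvm import relay, autotvm
from tvm.relay import testing

mod, params = testing.resnet.get_workload(num_layers=18, batch_size=1)
target = tvm.target.create("llvm")

# extract_from_program() lowers the module via _lower(); after this change,
# graph codegen is tried first and the VM compiler is only a fallback.
tasks = autotvm.task.extract_from_program(mod["main"], params=params,
                                          target=target,
                                          ops=(relay.op.get("nn.conv2d"),))
print("extracted %d tasks" % len(tasks))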
@@ -22,7 +22,8 @@ This module defines the task data structure, as well as a collection(zoo)
 of typical tasks of interest.
 """
-from .task import Task, create, get_config, args_to_workload, template
+from .task import Task, create, get_config, args_to_workload, template, \
+    serialize_args, deserialize_args
 from .space import ConfigSpace, ConfigEntity
 from .code_hash import attach_code_hash, attach_code_hash_to_arg
 from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest, \
......
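For context, the two helpers re-exported above convert between live placeholder tensors and plain, serializable argument tuples, which is what the updated graph tuner test relies on. A hedged round-trip sketch (shapes and dtype are illustrative assumptions):

import tvm
from tvm import autotvm

# Placeholders like the ones the updated graph tuner test builds.
data = tvm.te.placeholder((1, 3, 8, 8), dtype="float32")
kernel = tvm.te.placeholder((16, 3, 3, 3), dtype="float32")

# Tensors are replaced by plain descriptors; non-tensor values pass through.
args = autotvm.task.serialize_args([data, kernel, (1, 1), (1, 1, 1, 1)])
# deserialize_args rebuilds placeholder tensors from those descriptors.
restored = autotvm.task.deserialize_args(args)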
@@ -47,11 +47,20 @@ def _lower(mod,
             mod, _ = relay.optimize(mod, target, params)
             grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
             grc.codegen(mod["main"])
             return

     # default case
-    compiler = relay.vm.VMCompiler()
-    if params:
-        compiler.set_params(params)
-    compiler.lower(mod, target=target)
+    # Try graph codegen first to extract autotvm tasks.
+    # If failed to compile, then fallback to use VM compiler.
+    # TODO: Currently VM compiler is likely to stack overflow for large models.
+    try:
+        opt_mod, _ = relay.optimize(mod, target, params)
+        grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+        grc.codegen(opt_mod["main"])
+    except tvm.TVMError:
+        compiler = relay.vm.VMCompiler()
+        if params:
+            compiler.set_params(params)
+        compiler.lower(mod, target=target)


 def extract_from_program(mod, params, target, target_host=None, ops=None):
......
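The VM-compiler path kept above as the fallback can also be driven directly, which is useful when checking whether a given module hits the stack overflow mentioned in the TODO. A minimal sketch under the assumption of a tiny hand-built Relay module (not from this commit):

import tvm
from tvm import relay

# Tiny illustrative module; any Relay module would do.
x = relay.var("x", shape=(1, 8), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

# Same calls as the except-branch of _lower above.
compiler = relay.vm.VMCompiler()
compiler.lower(mod, target="llvm")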
@@ -34,6 +34,13 @@ from tvm.autotvm.measure import MeasureResult, MeasureInput
 from tvm.autotvm.graph_tuner import DPTuner, PBQPTuner

+def _create_args(dshape, kshape, strides, padding, dilation, layout, out_layout,
+                 dtype, out_dtype):
+    data = tvm.te.placeholder(dshape, dtype=dtype)
+    kernel = tvm.te.placeholder(kshape, dtype=dtype)
+    return autotvm.task.serialize_args([data, kernel, strides, padding, dilation,
+                                        layout, layout, out_dtype])
+
 def _create_data(target, dshape, dtype, layout):
     data = relay.var("data", shape=dshape, dtype=dtype)
     w0 = relay.var("w0_weight")
@@ -49,6 +56,12 @@ def _create_data(target, dshape, dtype, layout):
                                                  target=target,
                                                  params=params,
                                                  ops=(relay.op.get("nn.conv2d"),))
+    new_args = [
+        _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
+        _create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
     costs = [0.04, 0.012, 0.03]
     config_list = []
     cfg_dict = {"index": -1,
@@ -74,7 +87,8 @@ def _create_data(target, dshape, dtype, layout):
         config_list.append(ConfigEntity.from_json_dict(cfg_dict))

     records = []
-    for cost, config, task in zip(costs, config_list, tasks):
+    for args, cost, config, task in zip(new_args, costs, config_list, tasks):
+        task.args = args
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
@@ -261,6 +275,12 @@ def test_many_sub_graphs():
                                                  target=target,
                                                  params=params,
                                                  ops=(conv2d,))
+    new_args = [
+        _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
+        _create_args((1, 32, 8, 8), (32, 32, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
     costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
     config_list = []
     cfg_dict = {"index": -1,
@@ -307,9 +327,10 @@ def test_many_sub_graphs():
         config_list.append(ConfigEntity.from_json_dict(cfg_dict))

     records = []
+    new_args = new_args + new_args
     tasks = tasks + tasks
-    for cost, config, task in zip(costs, config_list, tasks):
+    for args, cost, config, task in zip(new_args, costs, config_list, tasks):
+        task.args = args
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
@@ -359,6 +380,10 @@ def test_tuple():
                                                  target=target,
                                                  params=params,
                                                  ops=(conv2d,))
+    new_args = [
+        _create_args((1, 5, 32, 32), (2, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args((1, 5, 32, 32), (3, 5, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
     costs = [0.01, 0.012, 0.03, 0.04]
     config_list = []
     cfg_dict = {"index": -1,
@@ -391,8 +416,10 @@ def test_tuple():
         config_list.append(ConfigEntity.from_json_dict(cfg_dict))

     records = []
+    new_args = new_args + new_args
     tasks = tasks + tasks
-    for cost, config, task in zip(costs, config_list, tasks):
+    for args, cost, config, task in zip(new_args, costs, config_list, tasks):
+        task.args = args
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
@@ -444,6 +471,11 @@ def test_triangle_block():
                                                  target=target,
                                                  params=params,
                                                  ops=(conv2d,))
+    new_args = [
+        _create_args((1, 3, 8, 8), (16, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+        _create_args((1, 16, 8, 8), (32, 16, 1, 1), (1, 1), (0, 0, 0, 0), (1, 1), layout, layout, dtype, dtype),
+        _create_args((1, 3, 8, 8), (32, 3, 3, 3), (1, 1), (1, 1, 1, 1), (1, 1), layout, layout, dtype, dtype),
+    ]
     costs = [0.04, 0.012, 0.03, 0.02, 0.02, 0.045]
     config_list = []
     cfg_dict = {"index": -1,
@@ -490,9 +522,10 @@ def test_triangle_block():
         config_list.append(ConfigEntity.from_json_dict(cfg_dict))

     records = []
+    new_args = new_args + new_args
     tasks = tasks + tasks
-    for cost, config, task in zip(costs, config_list, tasks):
+    for args, cost, config, task in zip(new_args, costs, config_list, tasks):
+        task.args = args
         ms_input = MeasureInput(target=target, task=task, config=config)
         ms_output = MeasureResult(costs=(cost,), error_no=0, all_cost=-1, timestamp=-1)
         records.append((ms_input, ms_output))
......
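The (MeasureInput, MeasureResult) pairs assembled in these tests mirror real tuning records. A hedged sketch of how such records would be written to a log file for a graph tuner run to consume (the file name and helper are assumptions, not part of this commit):

from tvm.autotvm.record import encode

def dump_records(records, path="graph_tuner_test.log"):
    # Each record is an (ms_input, ms_output) pair like the ones built above.
    with open(path, "w") as fout:
        for ms_input, ms_output in records:
            fout.write(encode(ms_input, ms_output) + "\n")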