Commit 3770368f by hlu1 Committed by Lianmin Zheng

[Autotvm] Support override (#3292)

parent d43aab07
...@@ -284,7 +284,7 @@ class TaskExtractEnv: ...@@ -284,7 +284,7 @@ class TaskExtractEnv:
return TaskExtractEnv.current return TaskExtractEnv.current
def register_topi_compute(topi_compute, target_keys, template_keys, func=None): def register_topi_compute(topi_compute, target_keys, template_keys, func=None, override=False):
"""Register a tunable template for a topi compute function. """Register a tunable template for a topi compute function.
After the registration, this topi compute will become a configuration dispatcher. It uses After the registration, this topi compute will become a configuration dispatcher. It uses
...@@ -333,7 +333,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None): ...@@ -333,7 +333,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_compute] config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_compute]
@config_dispatcher.register(template_keys) @config_dispatcher.register(template_keys, override=override)
def template_call(cfg, *args, **kwargs): def template_call(cfg, *args, **kwargs):
"""call the topi func and attach workload to compute node""" """call the topi func and attach workload to compute node"""
assert not kwargs, "Do not support kwargs in template function call" assert not kwargs, "Do not support kwargs in template function call"
...@@ -372,7 +372,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None): ...@@ -372,7 +372,7 @@ def register_topi_compute(topi_compute, target_keys, template_keys, func=None):
return _decorator return _decorator
def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None): def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None, override=False):
"""Register a tunable template for a topi schedule function. """Register a tunable template for a topi schedule function.
After the registration. This topi schedule will become a configuration dispatcher. It dispatches After the registration. This topi schedule will become a configuration dispatcher. It dispatches
...@@ -438,7 +438,7 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None) ...@@ -438,7 +438,7 @@ def register_topi_schedule(topi_schedule, target_keys, template_keys, func=None)
config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_schedule] config_dispatcher = _REGISTERED_DISPATCHER[target_key][topi_schedule]
@config_dispatcher.register(template_keys) @config_dispatcher.register(template_keys, override=override)
def template_call(cfg, outs, *args, **kwargs): def template_call(cfg, outs, *args, **kwargs):
"""call the schedule func""" """call the schedule func"""
if f == topi_schedule.fdefault: if f == topi_schedule.fdefault:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment