# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable,invalid-name
"""
Decorator and utilities for the integration with TOPI and Relay
99.9% copy-paste of implementation by @MerryMercy

"""
import logging
import threading
import warnings

from .task import create
from .topi_integration import TaskExtractEnv

logger = logging.getLogger('autotvm')


# TODO(moreau89) find a more elegant way to lower for VTAs
def _lower(mod,
           target,
           params):
    """ Helper to lower a Relay module properly, dispatching to the
    VTA-specific build path when the target is a VTA device.
    """

    from tvm import relay
    from tvm.relay.backend import graph_runtime_codegen

    if hasattr(target, 'device_name') and target.device_name == "vta":
        # VTA case: lowering must run inside vta.build_config() so the
        # accelerator-specific passes are applied
        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
            import vta
            with vta.build_config():
                mod, _ = relay.optimize(mod, target, params)
                grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
                return grc.codegen(mod["main"])

    # default case: optimize the module and generate graph code for the target
    mod, _ = relay.optimize(mod, target, params)
    grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
    return grc.codegen(mod["main"])
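
# A hedged sketch of how the two branches in _lower get selected. It assumes
# the vta package is installed and that tvm.target.vta() (available in this
# generation of TVM) returns a target whose device_name is "vta":
#
#     import tvm
#     cpu_target = tvm.target.create("llvm")  # device_name is "" -> default path
#     vta_target = tvm.target.vta()           # device_name == "vta" -> VTA path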


def extract_from_program(func, params, ops, target, target_host=None,
                         template_keys=None):
    """ Extract tuning tasks from a relay program.

61
    This function is the single program version of extract_from_multiple_program.

    Parameters
    ----------
    func: relay.expr.Function
        The function to tune
    params: dict of str to numpy array
        The associated parameters of the program
    ops: List of relay op
        List of relay ops to be tuned
    target: tvm.target.Target
        The compilation target
    target_host: tvm.target.Target
        The host compilation target
    template_keys: dict of topi op to str
        The tuning template keys map for schedules; defaults to None, which
        selects the 'direct' template for every task.
        Example: {topi.nn.conv2d: 'direct'}

    Returns
    -------
    tasks: List of autotvm.task.Task
        Collected tuning tasks
    """
    return extract_from_multiple_program([func], [params], ops, target, target_host,
                                         template_keys=template_keys)
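
# Example usage (a sketch, not part of the original file): extract conv2d
# tuning tasks from one network. get_network is a hypothetical helper that
# returns a (func, params) pair, e.g. built from relay.testing:
#
#     import tvm
#     from tvm import relay
#     func, params = get_network()  # hypothetical
#     tasks = extract_from_program(func, params,
#                                  ops=(relay.op.nn.conv2d,),
#                                  target=tvm.target.create("llvm"))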


def extract_from_multiple_program(funcs, params, ops, target, target_host=None,
                                  template_keys=None):
    """ Extract tuning tasks from multiple relay programs.

92 93
    This function collects tuning tasks by building a list of programs
    with a "tracing" target and tracing all the calls to topi.

    Parameters
    ----------
    funcs: List of relay.expr.Function
        The list of functions to tune
    params: List of dict of str to numpy array
        The associated parameters of the programs
    ops: List of relay op
        List of relay ops to be tuned
    target: tvm.target.Target
        The compilation target
    target_host: tvm.target.Target
        The host compilation target
    template_keys: dict of topi op to str
        The tuning template keys map for schedules; defaults to None, which
        selects the 'direct' template for every task.
        Example: {topi.nn.conv2d: 'direct'}

    Returns
    -------
    tasks: List of autotvm.task.Task
        Collected tuning tasks
    """
    import tvm.relay.op
    from tvm import relay
    import topi

    env = TaskExtractEnv.get()

    # NOTE: To add more ops, you only need to extend the following dict
    # relay op -> topi compute
    OP2TOPI = {
        tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                                 topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        tvm.relay.op.nn.dense: [topi.nn.dense],
        tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
        tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
    }
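
    # Per the NOTE above, supporting another relay op is a one-line addition to
    # this dict (a sketch; my_op and its TOPI compute are hypothetical and would
    # need to be registered with autotvm first):
    #
    #     OP2TOPI[tvm.relay.op.nn.my_op] = [topi.nn.my_op]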

    topi_funcs = []
    for op_name in ops:
        if op_name in OP2TOPI:
            topi_funcs.extend(OP2TOPI[op_name])
        else:
            warnings.warn("Op %s is not tunable, ignored" % op_name)

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)
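
    # TaskExtractEnv acts as a tracer: while active, every call into one of
    # topi_funcs is recorded as a (task_name, args) pair, retrievable through
    # env.get_tasks(). A sketch of the recorded shape (values assumed):
    #
    #     [("topi_nn_conv2d", (('TENSOR', (1, 3, 224, 224), 'float32'), ...))]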
    with env:
        # disable logger temporarily
        old_state = logger.disabled
        logger.disabled = True

        for func, param in zip(funcs, params):
            relay.backend.compile_engine.get().clear()
            # wrap the build call in a thread to avoid multiprocessing problems
            mod = relay.Module.from_expr(func)
            build_thread = threading.Thread(target=_lower,
                                            args=(mod, target, param))
            build_thread.start()
            build_thread.join()

        logger.disabled = old_state

    # convert *topi op to template key* map to *task name to template key* map
    task_name_to_keys = {}
    if template_keys is not None:
        for op in template_keys.keys():
            if op in env.topi_to_task:
                task_name_to_keys[env.topi_to_task[op]] = template_keys[op]
            else:
                # an op unknown to TaskExtractEnv cannot be mapped to a task
                # name (indexing topi_to_task here would raise KeyError), so
                # warn and let the loop below fall back to 'direct'
                logger.warning("Invalid template key for op %s, fallback to 'direct'", op)
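
    # For example (a sketch): template_keys={topi.nn.conv2d: 'winograd'} becomes
    # {"topi_nn_conv2d": "winograd"}, assuming TaskExtractEnv registers
    # topi.nn.conv2d under the task name "topi_nn_conv2d".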

    # create tasks for target
    tasks = []
    for task_name, args in env.get_tasks():
        try:
            key = task_name_to_keys.get(task_name, 'direct')
            tsk = create(task_name, args,
                         target=target, target_host=target_host,
                         template_key=key)
            tasks.append(tsk)
        except topi.InvalidShapeError:
            logger.warning("Invalid shape during AutoTVM task creation")

    return tasks
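
# Example usage (a sketch, not part of the original file): collect tasks from
# two networks in one pass. func1/params1 and func2/params2 are hypothetical
# (func, params) pairs, e.g. built with relay.testing:
#
#     import tvm
#     from tvm import relay
#     tasks = extract_from_multiple_program([func1, func2], [params1, params2],
#                                           ops=(relay.op.nn.conv2d,
#                                                relay.op.nn.dense),
#                                           target=tvm.target.create("llvm"))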