# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable,invalid-name
"""
Decorator and utilities for the integration with TOPI and Relay
99.9% copy-paste of implementation by @MerryMercy

"""
import threading
import warnings
import logging


from .task import create
from .topi_integration import TaskExtractEnv

logger = logging.getLogger('autotvm')


# TODO(moreau89) find a more elegant way to build for VTAs
def _build(func,
           target,
           target_host,
           params):
    """ Helper to build VTA properly.
    """

    from tvm import relay

    if hasattr(target, 'device_name') and target.device_name == "vta":
        with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
            import vta
            with vta.build_config():
                return relay.build(func, target, target_host, params)
    # default case
    return relay.build(func, target, target_host, params)

def extract_from_program(func, params, ops, target, target_host=None):
    """ Extract tuning tasks from a relay program.

    This function collects tuning tasks by building the program
    with a "tracing" target and tracing all the calls to topi.

    Parameters
    ----------
    func: relay.expr.Function
        The function to tune
    params: dict of str to numpy array
        The associated parameters of the program
    ops: List of relay op
        List of relay ops to be tuned
    target: tvm.target.Target
        The compilation target
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks
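
    Examples
    --------
    A minimal usage sketch. The ``relay.testing`` workload helper used here
    is only for illustration; its exact name and return values may differ
    between TVM versions::

        import tvm
        from tvm import relay, autotvm

        # hypothetical workload: a ResNet-18 function and its parameter dict
        net, params = relay.testing.resnet.get_workload(num_layers=18)

        tasks = autotvm.task.extract_from_program(
            net, params=params,
            ops=(relay.op.nn.conv2d, relay.op.nn.dense),
            target=tvm.target.create("llvm"))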
    """
    import tvm.relay.op
    from tvm import relay
    import topi

    env = TaskExtractEnv.get()

    # NOTE: To add more ops, you only need to change the following lists
    # relay op -> topi compute
    OP2TOPI = {
        tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                                 topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        tvm.relay.op.nn.dense: [topi.nn.dense],
        tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
    }

    topi_funcs = []
    for op_name in ops:
        if op_name in OP2TOPI:
            topi_funcs.extend(OP2TOPI[op_name])
        else:
            warnings.warn("Op %s is not tunable, ignored" % op_name)

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)
    with env:
        # disable logger temporarily
        old_state = logger.disabled
        logger.disabled = True

        relay.backend.compile_engine.get().clear()
        # wrap build call in thread to avoid multiprocessing problems
        mod = relay.Module.from_expr(func)
        build_thread = threading.Thread(target=_build,
                                        args=(mod,
                                              target,
                                              target_host,
                                              params))
        build_thread.start()
        build_thread.join()

        logger.disabled = old_state

    # create tasks for target
    tasks = []
    for task_name, args in env.get_tasks():
        try:
            tsk = create(task_name, args,
                         target=target, target_host=target_host,
                         template_key='direct')
            tasks.append(tsk)
        except topi.InvalidShapeError:
            warnings.warn("Invalid shape during AutoTVM task creation")
    return tasks


def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
    """ Extract tuning tasks from multiple relay programs.

    This function is the multiple-program version of extract_from_program.

    Parameters
    ----------
    funcs: List of relay.expr.Function
        The list of functions to tune
    params: List of dict of str to numpy array
        The associated parameters of the programs
    ops: List of relay op
        List of relay ops to be tuned
    target: tvm.target.Target
        The compilation target
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks
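
    Examples
    --------
    A minimal usage sketch; ``net_a``/``net_b`` and their parameter dicts
    ``params_a``/``params_b`` stand for workloads built elsewhere::

        import tvm
        from tvm import relay, autotvm

        tasks = autotvm.task.extract_from_multiple_program(
            [net_a, net_b], [params_a, params_b],
            ops=(relay.op.nn.conv2d,),
            target=tvm.target.create("llvm"))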
    """
    env = TaskExtractEnv.get()
    import tvm.relay.op
    from tvm import relay
    import topi

    # NOTE: To add more ops, you only need to change the following lists
    # relay op -> topi compute
    OP2TOPI = {
        tvm.relay.op.nn.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                                 topi.nn.group_conv2d_nchw],
        tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
        tvm.relay.op.nn.dense: [topi.nn.dense],
        tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
    }

    topi_funcs = []
    for op_name in ops:
        if op_name in OP2TOPI:
            topi_funcs.extend(OP2TOPI[op_name])
        else:
            warnings.warn("Op %s is not tunable, ignored" % op_name)

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)
    with env:
        # disable logger temporarily
        old_state = logger.disabled
        logger.disabled = True

        for func, param in zip(funcs, params):
            relay.backend.compile_engine.get().clear()
            # wrap build call in thread to avoid multiprocessing problems
            mod = relay.Module.from_expr(func)
            build_thread = threading.Thread(target=_build,
                                            args=(mod,
                                                  target,
                                                  target_host,
                                                  param))
            build_thread.start()
            build_thread.join()

        logger.disabled = old_state

    # create tasks for target
    tasks = []
    for task_name, args in env.get_tasks():
        try:
            tsk = create(task_name, args,
                         target=target, target_host=target_host,
                         template_key='direct')
            tasks.append(tsk)
        except topi.InvalidShapeError:
            print("[Warning] Invalid shape during AutoTVM task creation")

    return tasks