# pylint: disable=unused-variable,invalid-name
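"""
Decorator and utilities for the integration with TOPI and NNVM.

Tuning tasks are extracted from NNVM graphs by compiling them with a
"tracing" target and recording the TOPI calls made during compilation.
"""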
import warnings
import logging


from ... import tensor, placeholder, create_schedule, target as _target

from ..util import get_const_tuple
from .task import create, register

logger = logging.getLogger('autotvm')

def serialize_args(args):
    """Serialize the positional arguments of a TOPI call into a tuple.

    Every :code:`tensor.Tensor` is replaced by ``('TENSOR', shape, dtype)``,
    where ``shape`` is the tensor's shape as a tuple of constants. All other
    arguments are kept unchanged and in their original order. The result is
    hashable only if those non-tensor arguments are hashable.

    Parameters
    ----------
    args: list of hashable or Tensor
        Positional arguments of a TOPI function call.

    Returns
    -------
    ret: tuple
        The serialized arguments.
    """
    def _encode(arg):
        # non-tensor arguments pass straight through
        if not isinstance(arg, tensor.Tensor):
            return arg
        return ('TENSOR', get_const_tuple(arg.shape), arg.dtype)

    return tuple(_encode(arg) for arg in args)


def deserialize_args(args):
    """The inverse function of :code:`serialize_args`.

    Parameters
    ----------
    args: iterable of hashable
        Serialized arguments as produced by :code:`serialize_args`. Entries of
        the form ``('TENSOR', shape, dtype)`` are turned back into placeholder
        tensors; every other entry is passed through unchanged.

    Returns
    -------
    ret: list
        The deserialized arguments, in the original order.
    """
    ret = []
    for t in args:
        # Check the length before indexing. Plain tuple arguments (including
        # the empty tuple) must pass through instead of raising IndexError.
        # Only the exact 3-element ('TENSOR', shape, dtype) form written by
        # serialize_args is decoded.
        if isinstance(t, tuple) and len(t) == 3 and t[0] == 'TENSOR':
            ret.append(placeholder(shape=t[1], dtype=t[2]))
        else:
            ret.append(t)
    return ret


# Task extractor for nnvm graph
class TaskExtractEnv:
    """Global environment for extracting tuning tasks from nnvm graph.

    A single shared instance (see :meth:`get`) does two things:

    * It registers "tracing" implementations for a fixed set of TOPI compute
      and schedule functions.
    * It registers AutoTVM templates for the matching tasks.

    While a traced compute is called, it records ``(task_name,
    serialized_args)`` pairs into :attr:`task_collection`.
    """
    current = None

    # Assertion message used when a traced TOPI compute receives keyword
    # arguments. Only positional arguments are serialized into a task key.
    _KWARGS_UNSUPPORTED_MSG = ("Do not support extracting tuning tasks when "
                               "kwargs is used in TOPI function call. "
                               "Please modify it to use only positional args.")

    def __init__(self):
        import topi
        import nnvm

        # NOTE: To add more symbols, you only need to change the following lists
        # nnvm symbol -> topi compute
        self.symbol2topi = {
            nnvm.sym.conv2d: [topi.nn.conv2d, topi.nn.depthwise_conv2d_nchw,
                              topi.nn.group_conv2d_nchw],
            nnvm.sym.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
            nnvm.sym.dense: [topi.nn.dense],
        }

        # topi compute -> autotvm task name
        self.topi_to_task = {
            topi.nn.conv2d: "topi_nn_conv2d",
            topi.nn.depthwise_conv2d_nchw: "topi_nn_depthwise_conv2d_nchw",
            topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
            topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
            topi.nn.dense: "topi_nn_dense",
        }

        # topi compute -> generic schedules that get a "tracing" implementation
        self.topi_to_schedule = {
            topi.nn.conv2d: [topi.generic.schedule_conv2d_nchw,
                             topi.generic.schedule_conv2d_nhwc],
            topi.nn.depthwise_conv2d_nchw: [topi.generic.schedule_depthwise_conv2d_nchw,
                                            topi.generic.schedule_depthwise_conv2d_nhwc],
            topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
            topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
            topi.nn.dense: [topi.generic.schedule_dense],
        }

        # Initialise the recording state before registering any tracing
        # hook. The hooks read these attributes.
        self.task_collection = []
        self.wanted_topi_funcs = list(self.topi_to_task.keys())

        self._register_tracing()
        self._register_topi_task()

    def _register_tracing(self):
        """Register "tracing" implementations of the TOPI computes and schedules.

        A traced compute first records its call, if it is in
        :attr:`wanted_topi_funcs`. It then returns the result of the compute's
        default implementation. A traced schedule just builds a plain schedule
        over the given outputs.
        """
        # register topi compute for "tracing" target
        for topi_compute in self.topi_to_task:
            def _local_scope(compute_func):
                """start a scope to hold the local function in for loop"""
                # The extra function scope binds ``compute_func`` per iteration.
                # A closure over the loop variable would only see its last value.
                @compute_func.register("tracing", )
                def _tracing_topi_compute(*args, **kwargs):
                    assert not kwargs, self._KWARGS_UNSUPPORTED_MSG

                    if compute_func in self.wanted_topi_funcs:  # record this call
                        key = (self.topi_to_task[compute_func], serialize_args(args))
                        if key not in self.task_collection:
                            self.task_collection.append(key)

                    return compute_func.fdefault(*args)
            _local_scope(topi_compute)

        # register topi schedule for "tracing" target
        for topi_compute in self.topi_to_task:
            for topi_schedule in self.topi_to_schedule[topi_compute]:
                def _local_scope_(schedule_func):
                    """start a scope to hold the local function in for loop"""
                    @schedule_func.register("tracing", )
                    def _tracing_topi_schedule(outs):
                        outs = [outs] if isinstance(outs, tensor.Tensor) else outs
                        return create_schedule([x.op for x in outs])
                _local_scope_(topi_schedule)

    def _register_topi_task(self):
        """register tuning wrapper for topi function"""
        import topi

        # Tuning wrapper for topi functions
        @register("topi_nn_conv2d")
        def _topi_nn_conv2d(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            # NOTE(review): assumes the layout is the second-to-last positional
            # argument of topi.nn.conv2d (i.e. out_dtype is passed last) -- confirm
            layout = args[-2]
            assert layout == 'NCHW', "only support NCHW currently"
            C = topi.nn.conv2d(*args, **kwargs)
            s = topi.generic.schedule_conv2d_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_depthwise_conv2d_nchw")
        def _topi_nn_depthwise_conv2d_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            C = topi.nn.depthwise_conv2d_nchw(*args, **kwargs)
            s = topi.generic.schedule_depthwise_conv2d_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_group_conv2d_nchw")
        def _topi_nn_group_conv2d_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            C = topi.nn.group_conv2d_nchw(*args, **kwargs)
            s = topi.generic.schedule_group_conv2d_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_conv2d_transpose_nchw")
        def _topi_nn_conv2d_transpose_nchw(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            A, W = args[:2]
            C = topi.nn.conv2d_transpose_nchw(*args, **kwargs)
            s = topi.generic.schedule_conv2d_transpose_nchw([C])
            return s, [A, W, C]

        @register("topi_nn_dense")
        def _topi_nn_dense(*args, **kwargs):
            assert not kwargs, "Do not support kwargs in template function call"
            args = deserialize_args(args)
            # exactly three positional args are expected: data, weight, bias
            data, weight, bias = args
            C = topi.nn.dense(*args, **kwargs)
            s = topi.generic.schedule_dense([C])
            if bias is not None:
                return s, [data, weight, bias, C]
            return s, [data, weight, C]

    def reset(self, wanted_topi_funcs):
        """Reset task collections

        Parameters
        ----------
        wanted_topi_funcs: List of function
            The topi function to be extracted
        """
        self.task_collection = []
        self.wanted_topi_funcs = wanted_topi_funcs

    def get_tasks(self):
        """Get collected tasks

        Returns
        -------
        tasks: List of tuple(name, args)
            A list of tasks extracted from the nnvm graph
        """
        return self.task_collection

    @staticmethod
    def get():
        """Get the single instance of TaskExtractEnv

        The instance is created lazily on the first call and reused afterwards.

        Returns
        -------
        env: TaskExtractEnv
            The single instance of TaskExtractEnv
        """
        if TaskExtractEnv.current is None:
            TaskExtractEnv.current = TaskExtractEnv()
        return TaskExtractEnv.current


def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
    """ Extract tuning tasks from a nnvm graph.

    This function collects tuning tasks by building the graph
    with a "tracing" target and tracing all the calls to topi.

    Parameters
    ----------
    graph : Graph
        The graph to tune
    shape : dict of str to tuple
        The input shape to the graph
    dtype : str or dict of str to str
        The input types to the graph
    target: tvm.target.Target
        The compilation target
    symbols : Array of nnvm.symbol
        Array of nnvm symbols want to be tuned
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks
    """
    import nnvm.compiler

    env = TaskExtractEnv.get()

    # map the requested nnvm symbols to the topi computes that implement them
    topi_funcs = []
    for sym_name in symbols:
        if sym_name in env.symbol2topi:
            topi_funcs.extend(env.symbol2topi[sym_name])
        else:
            # wrap in a 1-tuple so a tuple-valued symbol cannot break formatting
            warnings.warn("Symbol %s is not tunable, ignored" % (sym_name,))

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)

    # Disable the autotvm logger during the fake compile. Restore it in
    # `finally` so a failing build does not leave logging disabled.
    old_state = logger.disabled
    logger.disabled = True
    try:
        # use a "tracing" target to do a fake compile for collecting topi calls
        tracing_target = _target.create("llvm -device=tracing")
        nnvm.compiler.engine.clear_cache()
        nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
    finally:
        logger.disabled = old_state

    # create tasks for target
    return [create(task_name, args,
                   target=target, target_host=target_host,
                   template_key='direct')
            for task_name, args in env.get_tasks()]


def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None):
    """ Extract tuning tasks from multiple nnvm graphs.

    This function is the multiple graph version of extract_from_graph

    Parameters
    ----------
    graphs : List of Graph
        The list of graphs to tune
    shapes : List of dict of str to tuple, or dict of str to tuple
        The input shape of each graph. A single dict is used for every graph.
    dtypes : List of (str or dict of str to str), or str, or dict of str to str
        The input types of each graph. A single str or dict is used for
        every graph.
    target: tvm.target.Target
        The compilation target
    symbols : Array of nnvm.symbol
        Array of nnvm symbols want to be tuned
    target_host: tvm.target.Target
        The host compilation target

    Returns
    -------
    task: Array of autotvm.task.Task
        collected tasks

    Raises
    ------
    ValueError
        If graphs, shapes and dtypes (after broadcasting) differ in length.
    """
    graphs = list(graphs)
    # Broadcast a single shape dict or dtype to all graphs. Otherwise zip()
    # would iterate over the dict's keys or the string's characters.
    if isinstance(shapes, dict):
        shapes = [shapes] * len(graphs)
    if isinstance(dtypes, (str, dict)):
        dtypes = [dtypes] * len(graphs)
    shapes = list(shapes)
    dtypes = list(dtypes)
    # zip() silently truncates to the shortest input, which would drop graphs
    if not len(graphs) == len(shapes) == len(dtypes):
        raise ValueError("graphs, shapes and dtypes must have the same length, "
                         "got %d, %d and %d" % (len(graphs), len(shapes), len(dtypes)))

    import nnvm.compiler

    env = TaskExtractEnv.get()

    # map the requested nnvm symbols to the topi computes that implement them
    topi_funcs = []
    for sym_name in symbols:
        if sym_name in env.symbol2topi:
            topi_funcs.extend(env.symbol2topi[sym_name])
        else:
            # wrap in a 1-tuple so a tuple-valued symbol cannot break formatting
            warnings.warn("Symbol %s is not tunable, ignored" % (sym_name,))

    # run compiler to collect all TOPI calls during compilation
    env.reset(topi_funcs)

    # Disable the autotvm logger during the fake compiles. Restore it in
    # `finally` so a failing build does not leave logging disabled.
    old_state = logger.disabled
    logger.disabled = True
    try:
        # use a "tracing" target to do a fake compile for collecting topi calls
        tracing_target = _target.create("llvm -device=tracing")

        nnvm.compiler.engine.clear_cache()
        for graph, shape, dtype in zip(graphs, shapes, dtypes):
            nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype)
    finally:
        logger.disabled = old_state

    # create tasks for target
    return [create(task_name, args,
                   target=target, target_host=target_host,
                   template_key='direct')
            for task_name, args in env.get_tasks()]