target.py 15.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
"""Target management API of TVM.

TVM's target string is in format ``<target_name> [-option=value]...``.

Note
----
The list of options include:

- **-device=<device name>**

   The device name.

- **-mtriple=<target triple>** or **-target**

   Specify the target triple, which is useful for cross
   compilation.

- **-mcpu=<cpuname>**

   Specify a specific chip in the current architecture to
   generate code for. By default this is inferred from the
   target triple and autodetected to the current architecture.

- **-mattr=a1,+a2,-a3,...**

   Override or control specific attributes of the target,
   such as whether SIMD operations are enabled or not. The
   default set of attributes is set by the current CPU.

- **-system-lib**

   Build TVM system library module. System lib is a global module that contains
   functions which register themselves at program startup. Users can get the
   module using :any:`tvm.module.system_lib`.
   It is useful in environments where dynamic loading APIs such as dlopen are banned.
   The system lib will be available as long as the resulting code is linked into the program.

We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific functions in this module to create specific targets.
"""
57 58
from __future__ import absolute_import

59 60
import warnings

61
from ._ffi.base import _LIB_NAME
62 63
from ._ffi.node import NodeBase, register_node
from . import _api_internal
64 65 66 67 68 69 70 71 72 73 74 75 76

try:
    from decorator import decorate
except ImportError:
    # The decorator package is optional only for the runtime-only library;
    # for any other library name the import failure is propagated unchanged.
    if _LIB_NAME != "libtvm_runtime.so":
        raise

def _merge_opts(opts, new_opts):
    """Helper function to merge options"""
    if isinstance(new_opts, str):
        new_opts = new_opts.split()
    if new_opts:
77 78
        opt_set = set(opts)
        new_opts = [opt for opt in new_opts if opt not in opt_set]
79 80 81 82
        return opts + new_opts
    return opts


83 84
@register_node
class Target(NodeBase):
    """Target device information, use through TVM API.

    Note
    ----
    Do not use class constructor, you can create target using the following functions

    - :any:`tvm.target.create` create target from string
    - :any:`tvm.target.arm_cpu` create arm_cpu target
    - :any:`tvm.target.cuda` create CUDA target
    - :any:`tvm.target.rocm` create ROCM target
    - :any:`tvm.target.mali` create Mali target
    - :any:`tvm.target.intel_graphics` create Intel Graphics target
    """
    def __new__(cls):
        # Always override new to enable class
        obj = NodeBase.__new__(cls)
        # Python-side caches for the list properties below.
        # None means "not computed yet".
        obj._keys = None
        obj._options = None
        obj._libs = None
        return obj

    @property
    def keys(self):
        """list of str: the values of ``keys_array``, computed once and cached."""
        # Compare against None (not truthiness) so an empty list is also
        # treated as cached instead of being rebuilt on every access.
        if self._keys is None:
            self._keys = [key.value for key in self.keys_array]
        return self._keys

    @property
    def options(self):
        """list of str: the values of ``options_array``, computed once and cached."""
        if self._options is None:
            self._options = [opt.value for opt in self.options_array]
        return self._options

    @property
    def libs(self):
        """list of str: the values of ``libs_array``, computed once and cached."""
        if self._libs is None:
            self._libs = [lib.value for lib in self.libs_array]
        return self._libs

    @property
    def model(self):
        """str: text following ``-model=`` in the first option that starts
        with that prefix, or ``'unknown'`` when no option does."""
        prefix = '-model='
        for opt in self.options_array:
            if opt.value.startswith(prefix):
                return opt.value[len(prefix):]
        return 'unknown'

    def __enter__(self):
        """Enter the scope of this target and return the target itself."""
        _api_internal._EnterTargetScope(self)
        return self

    def __exit__(self, ptype, value, trace):
        """Leave the scope of this target; the exception details are not used."""
        _api_internal._ExitTargetScope(self)
137

138

139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
@register_node
class GenericFunc(NodeBase):
    """Reference to a GenericFunc node.

    A generic function may have specializations for different targets.
    Calling this object picks a specialization based on the current target.

    Note
    ----
    Instances are not meant to be constructed directly; they should only
    ever be obtained as return values of calls into C++.
    """
    def __call__(self, *args):
        """Invoke the generic function with positional arguments only."""
        result = _api_internal._GenericFuncCallFunc(self, *args)
        return result

    def set_default(self, func, allow_override=False):
        """Set the function used when no specialization matches the
        current target.

        Parameters
        ----------
        func : function
            The default function

        allow_override : bool
            Whether the current default may be replaced
        """
        _api_internal._GenericFuncSetDefault(self, func, allow_override)

    def register(self, func, key_list, allow_override=False):
        """Register a specialization for this GenericFunc.

        Parameters
        ----------
        func : function
            The function to be registered.

        key_list : str or list of str
            The key or keys to register ``func`` under. A single string is
            wrapped into a one-element list.

        allow_override : bool, optional
            Whether existing keys may be overridden.
        """
        if isinstance(key_list, str):
            key_list = [key_list]
        _api_internal._GenericFuncRegisterFunc(self, func, key_list, allow_override)

184

185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
def get_native_generic_func(name):
    """Fetch a generic function from the global registry.

    When nothing is registered under ``name`` yet, a new generic
    function is created.

    Parameters
    ----------
    name : string
        The name of the generic function to get

    Returns
    -------
    func : GenericFunc
        The generic function for the given name
    """
    native_func = _api_internal._GenericFuncGetGlobal(name)
    return native_func

202

203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
def override_native_generic_func(func_name):
    """Override a generic function defined in C++

    A generic function allows registration of further functions
    that can be dispatched on the current target context.
    If no registered dispatch is matched, the fdefault will be called.

    Parameters
    ----------
    func_name : string
        The name of the generic func to be overridden

    Returns
    -------
    fgeneric : function
        A wrapped generic function.

    Example
    -------
    .. code-block:: python

      import tvm
      # wrap function as target generic
      @tvm.target.override_native_generic_func("my_func")
      def my_func(a):
          return a + 1
      # register specialization of my_func under target cuda
      @my_func.register("cuda")
      def my_func_cuda(a):
          return a + 2
      # displays 3, because my_func is called
      print(my_func(2))
      # displays 4, because my_func_cuda is called
      with tvm.target.cuda():
          print(my_func(2))
    """
    native_node = get_native_generic_func(func_name)

    def fdecorate(fdefault):
        """Wrap a target generic function, replacing the default that was
        previously set for the generic function.

        Parameters
        ----------
        fdefault : function
            The default function.

        Returns
        -------
        fgeneric : function
            A wrapped generic function.
        """
        native_node.set_default(fdefault, allow_override=True)

        def register(key, func=None, override=True):
            """Register function to be the dispatch function.

            Parameters
            ----------
            key : str or list of str
                The key to be registered.

            func : function
                The function to be registered.

            override : bool, optional
                Whether to override an existing registration.

            Returns
            -------
            The register function, if necessary: when ``func`` is omitted a
            decorator is returned, otherwise ``func`` itself after registration.
            """
            def _do_reg(myf):
                native_node.register(myf, key, override)
                return myf
            return _do_reg(func) if func else _do_reg

        def dispatch_func(func, *args, **kwargs):
            #pylint: disable=unused-argument
            """The wrapped dispatch function"""
            if not kwargs:
                return native_node(*args)
            raise RuntimeError(
                "Keyword arguments cannot be used when invoking generic_func %s" % func_name)

        fresult = decorate(fdefault, dispatch_func)
        # Expose the registration hook and the original default on the wrapper.
        fresult.register = register
        fresult.fdefault = fdefault
        return fresult
    return fdecorate
295 296 297 298

def generic_func(fdefault):
    """Wrap a target generic function.

    Generic function allows registration of further functions
    that can be dispatched on current target context.
    If no registered dispatch is matched, the fdefault will be called.

    Parameters
    ----------
    fdefault : function
        The default function.

    Returns
    -------
    fgeneric : function
        A wrapped generic function. It carries two extra attributes:
        ``register`` (see below) and ``fdefault`` (the original function).

    Example
    -------
    .. code-block:: python

      import tvm
      # wrap function as target generic
      @tvm.target.generic_func
      def my_func(a):
          return a + 1
      # register specialization of my_func under target cuda
      @my_func.register("cuda")
      def my_func_cuda(a):
          return a + 2
      # displays 3, because my_func is called
      print(my_func(2))
      # displays 4, because my_func_cuda is called
      with tvm.target.cuda():
          print(my_func(2))
    """
    # Maps a target key to the function registered for it.
    dispatch_dict = {}
    func_name = fdefault.__name__

    def register(key, func=None, override=False):
        """Register function to be the dispatch function.

        Parameters
        ----------
        key : str or list of str
            The key to be registered.

        func : function
            The function to be registered.

        override : bool
            Whether to override an existing registration.

        Returns
        -------
        The register function, if necessary: when ``func`` is omitted a
        decorator is returned, otherwise ``func`` itself after registration.

        Raises
        ------
        ValueError
            If ``override`` is False and any of the keys is already
            registered. In that case no key is registered.
        """
        def _do_reg(myf):
            # list() so that a one-shot iterable can be walked twice below.
            key_list = [key] if isinstance(key, str) else list(key)
            if not override:
                # Validate every key before touching dispatch_dict so that a
                # rejected registration leaves no partially registered keys.
                for k in key_list:
                    if k in dispatch_dict:
                        raise ValueError(
                            "Key is already registered for %s" % func_name)
            for k in key_list:
                dispatch_dict[k] = myf
            return myf
        # Explicit None test: a callable that happens to be falsy is still
        # registered directly instead of being mistaken for "no function".
        if func is not None:
            return _do_reg(func)
        return _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function"""
        target = current_target()
        if target is not None:
            # The first target key with a registered function wins.
            for k in target.keys:
                if k in dispatch_dict:
                    return dispatch_dict[k](*args, **kwargs)
        # No current target or no matching key: fall back to the default.
        return func(*args, **kwargs)

    fdecorate = decorate(fdefault, dispatch_func)
    fdecorate.register = register
    fdecorate.fdefault = fdefault
    return fdecorate

379

380
def cuda(model='unknown', options=None):
    """Create a cuda target.

    Parameters
    ----------
    model: str
        The model of cuda device (e.g. 1080ti)
    options : str or list of str
        Additional options

    Returns
    -------
    The result of creating a "cuda" target with a ``-model`` option
    followed by the additional options.
    """
    base_opts = ['-model=%s' % model]
    merged = _merge_opts(base_opts, options)
    return _api_internal._TargetCreate("cuda", *merged)
392 393


394
def rocm(model='unknown', options=None):
    """Create a ROCM target.

    Parameters
    ----------
    model: str
        The model of this device
    options : str or list of str
        Additional options

    Returns
    -------
    The result of creating a "rocm" target with a ``-model`` option
    followed by the additional options.
    """
    base_opts = ["-model=%s" % model]
    merged = _merge_opts(base_opts, options)
    return _api_internal._TargetCreate("rocm", *merged)
406 407


408
def mali(model='unknown', options=None):
    """Create an ARM Mali GPU target.

    Parameters
    ----------
    model: str
        The model of this device
    options : str or list of str
        Additional options

    Returns
    -------
    The result of creating an "opencl" target with ``-device=mali``, a
    ``-model`` option and the additional options.
    """
    merged = _merge_opts(["-device=mali", '-model=%s' % model], options)
    return _api_internal._TargetCreate("opencl", *merged)
421 422


423
def intel_graphics(model='unknown', options=None):
    """Create an Intel Graphics target.

    Parameters
    ----------
    model: str
        The model of this device
    options : str or list of str
        Additional options

    Returns
    -------
    The result of creating an "opencl" target with
    ``-device=intel_graphics``, a ``-model`` option and the additional options.
    """
    merged = _merge_opts(["-device=intel_graphics", '-model=%s' % model], options)
    return _api_internal._TargetCreate("opencl", *merged)


438
def opengl(model='unknown', options=None):
    """Create an OpenGL target.

    Parameters
    ----------
    model: str
        The model of this device
    options : str or list of str
        Additional options

    Returns
    -------
    The result of creating an "opengl" target with a ``-model`` option
    followed by the additional options.
    """
    base_opts = ["-model=%s" % model]
    merged = _merge_opts(base_opts, options)
    return _api_internal._TargetCreate("opengl", *merged)
448 449


450 451 452 453 454 455 456 457 458 459 460 461
def arm_cpu(model='unknown', options=None):
    """Create an ARM CPU target.

    Known board or phone names are translated to a predefined set of
    options; any other name is passed through as ``-model=<model>``.

    Parameters
    ----------
    model: str
        SoC name or phone name of the arm board.
    options : str or list of str
        Additional options

    Returns
    -------
    The result of creating an "llvm" target with ``-device=arm_cpu``, the
    model-specific options and the additional options.
    """
    # NOTE(review): the previous docstring said this function also downloads
    # pre-tuned op parameters; nothing in this function does that - confirm
    # whether it happens elsewhere.
    trans_table = {
        "pixel2":    ["-model=snapdragon835", "-target=arm64-linux-android -mattr=+neon"],
        "mate10":    ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
        "mate10pro": ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
        "p20":       ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
        "p20pro":    ["-model=kirin970", "-target=arm64-linux-android -mattr=+neon"],
        "rasp3b":    ["-model=bcm2837", "-target=armv7l-linux-gnueabihf -mattr=+neon"],
        "rk3399":    ["-model=rk3399", "-target=aarch64-linux-gnu -mattr=+neon"],
        "pynq":      ["-model=pynq", "-target=armv7a-linux-eabi -mattr=+neon"],
        "ultra96":   ["-model=ultra96", "-target=aarch64-linux-gnu -mattr=+neon"],
    }
    if model in trans_table:
        preset = trans_table[model]
    else:
        # Unknown names are forwarded verbatim as the model option.
        preset = ["-model=%s" % model]

    merged = _merge_opts(["-device=arm_cpu"] + preset, options)
    return _api_internal._TargetCreate("llvm", *merged)


def rasp(options=None):
    """Create a Raspberry 3b target.

    Emits a warning that this function is going to be deprecated, then
    delegates to ``arm_cpu('rasp3b', options)``.

    Parameters
    ----------
    options : str or list of str
        Additional options
    """
    message = ('tvm.target.rasp() is going to be deprecated. '
               'Please use tvm.target.arm_cpu("rasp3b")')
    warnings.warn(message)
    return arm_cpu('rasp3b', options)


492 493 494 495 496 497 498
def vta(model='unknown', options=None):
    """Create a VTA target.

    Parameters
    ----------
    model: str
        The model of this device
    options : str or list of str
        Additional options

    Returns
    -------
    The result of creating an "ext_dev" target with ``-device=vta``,
    ``-keys=cpu``, a ``-model`` option and the additional options.
    """
    base_opts = ["-device=vta", '-keys=cpu', '-model=%s' % model]
    merged = _merge_opts(base_opts, options)
    return _api_internal._TargetCreate("ext_dev", *merged)


499 500 501 502 503 504 505 506 507 508 509 510
def create(target_str):
    """Get a target given target string.

    Parameters
    ----------
    target_str : str or Target
        The target string. An existing Target is returned unchanged.

    Returns
    -------
    target : Target
        The target object

    Raises
    ------
    ValueError
        If ``target_str`` is neither a Target nor a string.

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    if isinstance(target_str, Target):
        # Already a target object: nothing to parse.
        return target_str
    if isinstance(target_str, str):
        return _api_internal._TargetFromString(target_str)
    raise ValueError("target_str has to be string type")
522 523


524 525
def current_target(allow_none=True):
    """Return the current target.

    Parameters
    ----------
    allow_none : bool
       Whether to allow the current target to be none

    Raises
    ------
    ValueError if current target is not set.
    (Raised by the underlying call; presumably only when ``allow_none``
    is False - TODO confirm.)
    """
    target = _api_internal._GetCurrentTarget(allow_none)
    return target