# target.py
"""Target management API of TVM.

TVM's target string is in format ``<target_name> [-option=value]...``.

Note
----
The list of options includes:

- **-device=<device name>**

   The device name.

- **-mtriple=<target triple>** or **-target**

   Specify the target triple, which is useful for cross
   compilation.

- **-mcpu=<cpuname>**

   Specify a specific chip in the current architecture to
   generate code for. By default this is inferred from the
   target triple and autodetected to the current architecture.

- **-mattr=a1,+a2,-a3,...**

   Override or control specific attributes of the target,
   such as whether SIMD operations are enabled or not. The
   default set of attributes is set by the current CPU.

- **-system-lib**

   Build TVM system library module. System lib is a global module that contains
   self registered functions in program startup. User can get the module using
   :any:`tvm.module.system_lib`.
   It is useful in environments where dynamic loading api like dlopen is banned.
   The system lib will be available as long as the result code is linked by the program.

We can use :any:`tvm.target.create` to create a tvm.target.Target from the target string.
We can also use other specific functions in this module to create specific targets.
"""
from __future__ import absolute_import

import warnings
from ._ffi.base import _LIB_NAME

try:
    from decorator import decorate
except ImportError:
    # ``decorate`` is only used by ``generic_func``. The runtime-only build
    # (libtvm_runtime.so) is allowed to run without the ``decorator`` package;
    # there ``decorate`` stays undefined and ``generic_func`` fails if called.
    # NOTE(review): only the ".so" library name is exempted here -- confirm
    # what _LIB_NAME holds for runtime-only builds on other platforms.
    if _LIB_NAME != "libtvm_runtime.so":
        raise


def _merge_opts(opts, new_opts):
    """Helper function to merge options"""
    if isinstance(new_opts, str):
        new_opts = new_opts.split()
    if new_opts:
59 60
        opt_set = set(opts)
        new_opts = [opt for opt in new_opts if opt not in opt_set]
61 62 63 64
        return opts + new_opts
    return opts


class Target(object):
    """Target device information, use through TVM API.

    Parameters
    ----------
    target_name : {"llvm", "cuda", "nvptx", "rocm", "opencl", "metal", "vulkan", "opengl", "stackvm", "ext_dev"}
        The major target name.

    options : str or list of str, optional
        Additional arguments appended to the target. A string is split on
        whitespace. ``-device=<name>`` sets ``device_name`` and
        ``-libs=<a>,<b>`` extends ``libs``.

    Attributes
    ----------
    keys : tuple of str
        Dispatch keys, most specific first: the device name (if any), then
        the keys implied by ``target_name``. :any:`generic_func` tries them
        in this order.

    thread_warp_size : int
        32 for ``cuda``/``nvptx``, otherwise 1.

    max_num_threads : int
        Only set for ``cuda``, ``nvptx``, ``rocm``, ``opencl``, ``metal`` and
        ``vulkan`` targets; the attribute is absent for the others.

    Note
    ----
    Do not use class constructor, you can create target using the following functions

    - :any:`tvm.target.create` create target from string
    - :any:`tvm.target.rasp` create raspberry pi target
    - :any:`tvm.target.cuda` create CUDA target
    - :any:`tvm.target.rocm` create ROCM target
    - :any:`tvm.target.mali` create Mali target
    """
    # Target of the innermost active ``with`` scope; None outside any scope.
    current = None

    def __init__(self,
                 target_name,
                 options=None):
        self.target_name = target_name
        self.options = _merge_opts([], options)
        self.device_name = ""
        self.libs = []
        # Parse -libs= and -device= options. Split only on the first "=" so
        # a value that itself contains "=" is kept whole.
        for item in self.options:
            if item.startswith("-libs="):
                libs = item.split("=", 1)[1]
                # Skip empty names produced by "-libs=" or a trailing comma.
                self.libs += [lib for lib in libs.split(",") if lib]
            elif item.startswith("-device="):
                self.device_name = item.split("=", 1)[1]
        # Target query searches the device name first
        if self.device_name:
            self.keys = (self.device_name,)
        else:
            self.keys = ()
        # Target configuration handling
        self.thread_warp_size = 1
        if target_name in ("llvm", ):
            self.keys += ("cpu",)
        elif target_name in ("cuda", "nvptx"):
            self.keys += ("cuda", "gpu")
            self.max_num_threads = 512
            self.thread_warp_size = 32
        elif target_name in ("rocm", "opencl"):
            # For now assume rocm schedule for opencl
            self.keys += ("rocm", "gpu")
            self.max_num_threads = 256
        elif target_name in ("metal", "vulkan"):
            self.keys += ("gpu",)
            self.max_num_threads = 256
        elif target_name in ("opengl",):
            self.keys += ("opengl",)
        elif target_name in ("stackvm", "ext_dev"):
            # No dispatch key class for stackvm or ext_dev
            pass
        else:
            raise ValueError("Unknown target name %s" % target_name)

    def __str__(self):
        """Render as a target string: the name followed by the options."""
        return " ".join([self.target_name] + self.options)

    def __repr__(self):
        return self.__str__()

    def __enter__(self):
        """Make this target current, warning if it replaces a different one."""
        self._old_target = Target.current
        if self._old_target is not None and str(self) != str(self._old_target):
            warnings.warn(
                "Override target '%s' with new target scope '%s'" % (
                    self._old_target, self))
        Target.current = self
        return self

    def __exit__(self, ptype, value, trace):
        """Restore the target that was current before ``__enter__``."""
        Target.current = self._old_target


def generic_func(fdefault):
    """Wrap a target generic function.

    Generic function allows registration of further functions
    that can be dispatched on current target context.
    If no registered dispatch is matched, the fdefault will be called.
    Dispatch walks the current target's ``keys`` in order and calls the
    first registered match.

    Parameters
    ----------
    fdefault : function
        The default function.

    Returns
    -------
    fgeneric : function
        A wrapped generic function. Its ``register`` attribute adds
        target-specific specializations.

    Example
    -------
    .. code-block:: python

      import tvm
      # wrap function as target generic
      @tvm.target.generic_func
      def my_func(a):
          return a + 1
      # register specialization of my_func under target cuda
      @my_func.register("cuda")
      def my_func_cuda(a):
          return a + 2
      # keys can also be registered directly, without the decorator form
      my_func.register(["opencl", "metal"], my_func_cuda)
      # displays 3, because my_func is called
      print(my_func(2))
      # displays 4, because my_func_cuda is called
      with tvm.target.cuda():
          print(my_func(2))
    """
    dispatch_dict = {}
    func_name = fdefault.__name__

    def register(key, func=None, override=False):
        """Register function to be the dispatch function.

        Parameters
        ----------
        key : str or list of str
            The key(s) to be registered.

        func : function, optional
            The function to be registered. When omitted, a decorator that
            performs the registration is returned instead.

        override : bool
            Whether to override an existing registration.

        Returns
        -------
        The registration decorator when ``func`` is None, otherwise ``func``.

        Raises
        ------
        ValueError
            If a key is already registered and ``override`` is False.
        """
        def _do_reg(myf):
            key_list = [key] if isinstance(key, str) else list(key)
            # Check every key before touching dispatch_dict so that a
            # conflict does not leave the earlier keys half-registered.
            if not override:
                for k in key_list:
                    if k in dispatch_dict:
                        raise ValueError(
                            "Key is already registered for %s" % func_name)
            for k in key_list:
                dispatch_dict[k] = myf
            return myf
        if func is not None:
            # Direct form ``register(key, func)``: register immediately.
            return _do_reg(func)
        return _do_reg

    def dispatch_func(func, *args, **kwargs):
        """The wrapped dispatch function."""
        target = current_target()
        if target is None:
            return func(*args, **kwargs)
        for k in target.keys:
            if k in dispatch_dict:
                return dispatch_dict[k](*args, **kwargs)
        return func(*args, **kwargs)
    fdecorate = decorate(fdefault, dispatch_func)
    fdecorate.register = register
    return fdecorate


def cuda(options=None):
    """Create a CUDA target.

    Parameters
    ----------
    options : str or list of str, optional
        Additional options, passed through unchanged to ``Target``.

    Returns
    -------
    target : Target
        A target named ``"cuda"``.
    """
    target_name = "cuda"
    return Target(target_name, options=options)


def rocm(options=None):
    """Create a ROCM target.

    Parameters
    ----------
    options : str or list of str, optional
        Additional options, passed through unchanged to ``Target``.

    Returns
    -------
    target : Target
        A target named ``"rocm"``.
    """
    target_name = "rocm"
    return Target(target_name, options=options)


def rasp(options=None):
    """Create a Raspberry Pi target.

    This is an ``llvm`` target preconfigured with ``-device=rasp``, the
    ``armv7l-none-linux-gnueabihf`` triple, ``-mcpu=cortex-a53`` and
    ``-mattr=+neon``.

    Parameters
    ----------
    options : str or list of str, optional
        Additional options appended after the defaults. An option identical
        to a default is dropped; a differing one (e.g. another ``-mcpu``) is
        appended rather than replacing the default.

    Returns
    -------
    target : Target
        The Raspberry Pi target.
    """
    defaults = [
        "-device=rasp",
        "-mtriple=armv7l-none-linux-gnueabihf",
        "-mcpu=cortex-a53",
        "-mattr=+neon",
    ]
    return Target("llvm", _merge_opts(defaults, options))


def mali(options=None):
    """Returns an ARM Mali GPU target.

    The result is an ``opencl`` target carrying ``-device=mali``, so its
    dispatch keys are ``("mali", "rocm", "gpu")``.

    Parameters
    ----------
    options : str or list of str, optional
        Additional options merged after ``-device=mali``; an exact duplicate
        of that default is dropped.

    Returns
    -------
    target : Target
        The Mali target.
    """
    opts = _merge_opts(["-device=mali"], options)
    return Target("opencl", opts)


def create(target_str):
    """Get a target given target string.

    Parameters
    ----------
    target_str : str or Target
        The target string. A ``Target`` instance is returned unchanged.

    Returns
    -------
    target : Target
        The target object

    Raises
    ------
    ValueError
        If ``target_str`` is neither a string nor a ``Target``, or if it
        contains no target name.

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    When ``-device=rasp`` or ``-device=mali`` is present, the target is built
    by :any:`tvm.target.rasp` / :any:`tvm.target.mali` and the leading target
    name in the string is not used.
    """
    if isinstance(target_str, Target):
        return target_str
    if not isinstance(target_str, str):
        raise ValueError("target_str has to be string type")
    arr = target_str.split()
    if not arr:
        raise ValueError("target_str must contain a target name")
    target_name, options = arr[0], arr[1:]
    # Parse device option; the last -device= option wins.
    device_name = ""
    for item in options:
        if item.startswith("-device="):
            device_name = item.split("=", 1)[1]
    if device_name == "rasp":
        return rasp(options)
    if device_name == "mali":
        return mali(options)
    return Target(target_name, options)


def current_target(allow_none=True):
    """Returns the current target.

    Parameters
    ----------
    allow_none : bool
       Whether to allow the current target to be None.

    Returns
    -------
    target : Target or None
        The target of the innermost active ``with Target`` scope, or None
        when no scope is active and ``allow_none`` is True.

    Raises
    ------
    RuntimeError
        If no current target is set and ``allow_none`` is False.
    """
    target = Target.current
    if target is None and not allow_none:
        raise RuntimeError(
            "Requires a current target in generic function, but it is not set. "
            "Please set it using `with TargetObject:`")
    return target