tag.py 2.79 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
"""Tag class for TVM operators."""
18
import warnings
19
from ._ffi.base import decorate
20 21 22 23 24

class TagScope(object):
    """Tag scope object to set tag for operators, working as context
    manager and decorator both. See also tag_scope.

    Only one scope may be active at a time: entering a scope while another
    one is active raises ``ValueError``. When a scope exits normally and its
    tag was never read through :meth:`get_current` during that entry, a
    warning is emitted.
    """

    # The scope currently entered, or None when no scope is active. It lives
    # on the class, so it is shared by every TagScope instance.
    _current = None

    @classmethod
    def get_current(cls):
        """Return the active scope (or None) and mark it as used.

        Returns
        -------
        scope : TagScope or None
            The scope currently entered, if any.
        """
        if cls._current is not None:
            # Reading the active scope counts as using its tag; this
            # suppresses the "not used" warning in __exit__.
            cls._current.accessed = True
        return cls._current

    def __init__(self, tag):
        # Scope that was active before this one was entered. Always None at
        # present, because __enter__ rejects nesting.
        self._old_scope = None
        self.tag = tag
        # Whether get_current() has returned this scope since the latest entry.
        self.accessed = False

    def __enter__(self):
        if TagScope._current is not None:
            raise ValueError("nested op_tag is not allowed for now")
        self._old_scope = TagScope._current
        # Reset on every entry: when used as a decorator the same instance is
        # re-entered on each call of the wrapped function, and each call must
        # be tracked on its own instead of inheriting an earlier call's flag.
        self.accessed = False
        TagScope._current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_scope is None
        # Restore the previous scope *before* warning: if warnings are turned
        # into exceptions (e.g. ``-W error``), the class must not be left
        # pointing at a scope that has already exited.
        TagScope._current = self._old_scope
        # Only report an unused tag when the body completed normally; if it
        # raised, the warning would just be noise next to the real error.
        if ptype is None and not self.accessed:
            warnings.warn("Tag '%s' declared via TagScope was not used." % (self.tag,),
                          stacklevel=2)

    def __call__(self, fdecl):
        """Use the scope as a decorator: each call of ``fdecl`` runs inside it.

        Parameters
        ----------
        fdecl : callable
            The function to wrap.

        Returns
        -------
        wrapped : callable
            Whatever ``decorate(fdecl, tagged_fdecl)`` returns.
        """
        def tagged_fdecl(func, *args, **kwargs):
            with self:
                return func(*args, **kwargs)
        # NOTE(review): decorate comes from ._ffi.base; tagged_fdecl receives
        # the original function as its first argument, which suggests a
        # signature-preserving wrapper helper — confirm.
        return decorate(fdecl, tagged_fdecl)
56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77


def tag_scope(tag):
    """Build a scope that sets ``tag`` for the operators declared inside it.

    The returned object works both as a context manager (``with``) and as a
    decorator. Scopes cannot be nested: entering one while another is
    active raises ``ValueError``.

    Parameters
    ----------
    tag: str
        The tag name.

    Returns
    -------
    tag_scope: TagScope
        The tag scope object, which can be used as decorator or
        context manager.

    Example
    -------
    .. code-block:: python

        n = tvm.var('n')
        m = tvm.var('m')
        l = tvm.var('l')
        A = tvm.placeholder((n, l), name='A')
        B = tvm.placeholder((m, l), name='B')
        k = tvm.reduce_axis((0, l), name='k')

        with tvm.tag_scope(tag='matmul'):
            C = tvm.compute((n, m), lambda i, j: tvm.sum(A[i, k] * B[j, k], axis=k))

        # or use tag_scope as decorator
        @tvm.tag_scope(tag="relu")
        def compute_relu(data):
            return tvm.compute(data.shape, lambda *i: tvm.select(data(*i) < 0, 0.0, data(*i)))
    """
    scope = TagScope(tag)
    return scope