# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Intrinsics of TVM-Python Hybrid Script for Python emulation runtime"""

import numpy
from tvm import target


class bind(object): #pylint: disable=invalid-name
    """GPU bind software emulation runtime.

    Iterating over an instance yields the integers ``0, 1, 2, ...`` that are
    strictly less than ``ext``, so a bound loop executes serially in Python.

    Parameters
    ----------
    _ : object
        First positional argument (presumably the thread tag being bound);
        it is ignored by the emulator.
    ext : int
        The loop extent; iteration stops before reaching it.
    """
    def __init__(self, _, ext):
        self.ext = ext

    def __iter__(self):
        # A counting loop (unlike ``range``) also works when ``ext`` is not an
        # ``int``, e.g. a float extent.  Each call starts a fresh generator,
        # so an instance can be iterated more than once.
        i = 0
        while i < self.ext:
            yield i
            i += 1


def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
    """Allocate a zero-initialized buffer with given shape

    Parameters
    ----------
    shape: Tuple
        The shape of the tensor to be allocated
    dtype: string
        The data type of the tensor
    scope: string
        The storage scope of the tensor; ignored by the Python emulator

    Returns
    -------
    tensor: numpy.array
        The tensor allocated, filled with zeros
    """
    # Allocate directly in the requested dtype instead of building a float64
    # array first and copying it through ``astype``.
    return numpy.zeros(shape, dtype=dtype)


def rsqrt(x):
    """
    Element-wise reciprocal square root, i.e. ``1 / sqrt(x)``.

    Parameters
    ----------
    x: Tensor
        Input value(s)

    Returns
    -------
    res: Tensor
        ``1 / sqrt(x)`` evaluated element-wise
    """
    # The numerator is a ones tensor shaped and typed like ``x``, so the
    # division follows NumPy's usual promotion between x's dtype and sqrt(x).
    numerator = numpy.ones_like(x)
    root = numpy.sqrt(x)
    return numpy.divide(numerator, root)


def popcount(x):
    """
    Count ones in the binary representation of number x

    Parameters
    ----------
    x: Integer
        The number to be counted. A Python ``int`` must be non-negative;
        other integer types (e.g. fixed-width NumPy integers) go through the
        same lowest-set-bit clearing loop.

    Returns
    -------
    cnt: Integer
        The number of ones in the binary representation of number x

    Raises
    ------
    ValueError
        If ``x`` is a negative Python ``int``.
    """
    # Python ints have unbounded width: for a negative value, clearing the
    # lowest set bit never reaches zero, so the loop below would never end.
    if isinstance(x, int) and x < 0:
        raise ValueError("popcount of a negative Python int is undefined: %d" % x)
    cnt = 0
    while x:
        x -= x & -x  # clear the lowest set bit
        cnt += 1
    return cnt


def sigmoid(x):
    """
    Logistic sigmoid of x, i.e. ``1 / (1 + exp(-x))``.

    Parameters
    ----------
    x: a real number
        Input value; since only NumPy element-wise operations are applied,
        a NumPy array is processed element-wise as well

    Returns
    -------
    res: a real number
        The logistic sigmoid of ``x``
    """
    denominator = 1 + numpy.exp(-x)
    return 1 / denominator


def max_num_threads(allow_none=True):
    """Get max number of threads for GPU targets.

    Parameters
    ----------
    allow_none : bool
        Forwarded unchanged to ``target.Target.current``.

    Returns
    -------
    The ``max_num_threads`` attribute of the target returned by
    ``target.Target.current(allow_none)``.
    """
    # NOTE(review): if ``Target.current`` returns None (presumably when no
    # target scope is active and allow_none is True), the attribute access
    # below raises AttributeError -- confirm whether callers rely on that.
    return target.Target.current(allow_none).max_num_threads


# Names written into a hybrid function's module globals while it is run under
# the Python emulation runtime.  Loop annotations degrade to a plain ``range``,
# math intrinsics map to NumPy or to the emulators defined in this module, and
# dtype names map to NumPy scalar types.  Entry order is the order in which the
# names are installed.
HYBRID_GLOBALS = {
    'unroll'         : range,
    'vectorize'      : range,
    'parallel'       : range,
    'const_range'    : range,
    'bind'           : bind,
    'allocate'       : allocate,
    'output_tensor'  : allocate,
    'sqrt'           : numpy.sqrt,
    'rsqrt'          : rsqrt,
    'log'            : numpy.log,
    'tanh'           : numpy.tanh,
    'power'          : numpy.power,
    'exp'            : numpy.exp,
    'sigmoid'        : sigmoid,
    'popcount'       : popcount,
    # Branch-likelihood hint: emulated as the identity on the condition.
    'likely'         : lambda cond: cond,
    'uint8'          : numpy.uint8,
    'uint16'         : numpy.uint16,
    'uint32'         : numpy.uint32,
    'uint64'         : numpy.uint64,
    'int8'           : numpy.int8,
    'int16'          : numpy.int16,
    'int32'          : numpy.int32,
    'int64'          : numpy.int64,
    'float16'        : numpy.float16,
    'float32'        : numpy.float32,
    'float64'        : numpy.float64,
    # Ceiling of a / b for integer a and positive integer b.
    'ceil_div'       : lambda a, b: (a + b - 1) // b,
    'max_num_threads': max_num_threads,
}


def _enter_hybrid_runtime(func):
    """Install every ``HYBRID_GLOBALS`` entry into ``func.__globals__``.

    Existing bindings with the same names are overwritten, but they are
    recorded first so the caller can restore them later.

    Parameters
    ----------
    func : function
        The function whose module globals are modified in place.

    Returns
    -------
    intersect : list of (str, object)
        ``(name, previous_value)`` for each ``HYBRID_GLOBALS`` name that was
        already present in ``func.__globals__``, in ``HYBRID_GLOBALS`` order.
    """
    func_globals = func.__globals__
    shadowed = []
    for name, intrinsic in HYBRID_GLOBALS.items():
        if name in func_globals:
            # Remember the user's binding before it is replaced.
            shadowed.append((name, func_globals[name]))
        func_globals[name] = intrinsic
    return shadowed


def _restore_runtime(func, intersect):
    """Undo the hybrid-runtime globals installed on ``func``.

    Every ``HYBRID_GLOBALS`` name is deleted from ``func.__globals__`` (a
    name that is missing raises ``KeyError``). Each saved ``(name, value)``
    pair from ``intersect`` is then written back.

    Parameters
    ----------
    func : function
        The function whose module globals were modified.
    intersect : list of (str, object)
        The shadowed bindings recorded when the runtime was entered.
    """
    module_globals = func.__globals__
    for name in HYBRID_GLOBALS:
        del module_globals[name]
    for name, saved in intersect:
        module_globals[name] = saved