.. _hybrid-langref-label:

Hybrid Frontend Language Reference
==================================

Overview
--------

The hybrid frontend allows users to write preliminary versions of idioms that are not yet
officially supported by TVM.

Features
--------

Software Emulation
~~~~~~~~~~~~~~~~~~

Both software emulation and compilation are supported. To define a function,
use the ``tvm.hybrid.script`` decorator to mark it as a hybrid function:

.. code-block:: python

    import numpy
    import tvm

    @tvm.hybrid.script
    def outer_product(a, b):
        # allocate the output tensor; its shape matches the inputs used below
        c = output_tensor((100, 99), 'float32')
        for i in range(a.shape[0]):
            for j in range(b.shape[0]):
                c[i, j] = a[i] * b[j]
        return c

    # software emulation: call the hybrid function directly on numpy arrays
    a = numpy.random.randn(100)
    b = numpy.random.randn(99)
    c = outer_product(a, b)


During software emulation, this decorator automatically imports the required `Keywords`_.
Once emulation finishes, the imported keywords are cleaned up, so users do not need to
worry about keyword conflicts or namespace pollution.

Every element passed in the argument list for software emulation is either a Python variable
or a ``numpy`` numeric type.
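
The import-and-clean-up behavior described above can be pictured in plain Python.
The sketch below is not TVM's implementation: ``emulate`` and ``_KEYWORDS`` are
hypothetical names, and mapping ``unroll`` to ``range`` is an assumption made only
for illustration.

.. code-block:: python

   import functools

   # Hypothetical keyword table; the real hybrid frontend defines its own.
   _KEYWORDS = {"unroll": range}

   def emulate(func):
       """Run func with the keywords temporarily injected into its globals."""
       @functools.wraps(func)
       def wrapper(*args):
           scope = func.__globals__
           missing = object()
           # remember whatever the caller had bound to these names
           saved = {name: scope.get(name, missing) for name in _KEYWORDS}
           scope.update(_KEYWORDS)
           try:
               return func(*args)
           finally:
               # restore the previous bindings so nothing leaks out
               for name, old in saved.items():
                   if old is missing:
                       del scope[name]
                   else:
                       scope[name] = old
       return wrapper

   @emulate
   def total(values):
       acc = 0
       for i in unroll(len(values)):  # resolved only while emulating
           acc += values[i]
       return acc

   result = total([1, 2, 3])
   leaked = "unroll" in globals()

After the call, ``unroll`` is no longer visible in the module namespace, which is
the kind of clean-up the decorator promises.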

Backend Compilation
~~~~~~~~~~~~~~~~~~~

Direct use of this interface is discouraged; users are encouraged to use the second interface below.
The current parse interface looks like:

.. code-block:: python

   a = tvm.placeholder((100, ), name='a')
   b = tvm.placeholder((99, ), name='b')
   parser = tvm.hybrid.parse(outer_product, [a, b]) # return the parser of this function


If we pass TVM data structures such as ``Tensor``, ``Var``, ``Expr.*Imm``,
or ``tvm.container.Array`` to this function, it returns an op node:

.. code-block:: python

   a = tvm.placeholder((100, ), name='a')
   b = tvm.placeholder((99, ), name='b')
   c = outer_product(a, b) # return the output tensor(s) of the operator

You can use any method that applies to a TVM ``OpNode``, such as ``create_schedule``, although
so far the scheduling functionality is as limited as that of ``ExternOpNode``. At the least, the
op can be built into an LLVM module.

Tuning
~~~~~~

Following the example above, you can use some TVM-like interfaces to tune the code:

.. code-block:: python

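   i, j = c.op.axis
   sch = tvm.create_schedule(c.op)  # schedule the op that produces c
   jo, ji = sch[c].split(j, 4)      # split j into an outer loop and an inner loop of extent 4
   sch[c].vectorize(ji)             # vectorize the inner loop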

For now, you can use loop annotations (``unroll``, ``parallel``, ``vectorize``, and ``bind``),
loop manipulation (``split`` and ``fuse``), and ``reorder``.

.. note::

        This is a preliminary feature, so users are responsible for checking that the
        code is still correct after tuning. In particular, be careful when
        fusing and reordering imperfect loops.
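
To see why loop transformations need care, it helps to picture what a split does
to the index space. The following is plain Python, not TVM code: ``split_indices``
is a hypothetical helper, and the guard on the tail iterations is an assumption
about what a split by a non-dividing factor has to do in order to stay correct.

.. code-block:: python

   def split_indices(extent, factor):
       """List the original indices visited after splitting a loop by factor."""
       outer = (extent + factor - 1) // factor  # ceiling division
       visited = []
       for jo in range(outer):
           for ji in range(factor):
               j = jo * factor + ji
               if j < extent:  # skip the padding iterations in the last chunk
                   visited.append(j)
       return visited

   # 99 is not a multiple of 4, so the last outer iteration is only partly used
   order = split_indices(99, 4)

Without the guard, the final outer iteration would touch index 99, which lies
outside the original loop.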

Loops
~~~~~

In HalideIR, there are four loop types in total: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.

Here we use four keywords, ``range`` (also known as ``serial``), ``unroll``, ``parallel``, and ``vectorize``,
to annotate the corresponding types of for loops.
Their usage is roughly the same as that of the Python standard ``range``.

Besides the loop types supported in Halide, ``const_range`` is supported for some specific conditions.
Sometimes it is desirable to pass a ``tvm.container.Array`` as an argument, but TVM-HalideIR has no
support for converting a ``tvm.container.Array`` to an ``Expr``. Thus, only a limited feature is supported:
users can access containers either by constants or by the loop variables of loops annotated with ``const_range``.

.. code-block:: python

   @tvm.hybrid.script
   def foo(a, b): # b is a tvm.container.Array
       c = output_tensor(a.shape, a.dtype)
       for i in const_range(len(a)): # b is indexed by i, so the loop must be explicitly annotated as const_range
           c[i] = a[i] + b[i]
       return c


Variables
~~~~~~~~~

All mutable variables are lowered to an array of size 1.
The first store to a variable is regarded as its declaration.

.. note::

        Unlike conventional Python, in hybrid script a declared variable
        can only be used at the scope level where it is declared.


.. note::

        Currently, you can ONLY use basic-typed variables, i.e. the type of the
        variable should be either ``float32`` or ``int32``.

.. code-block:: python

   for i in range(5):
       s = 0 # declaration, this s will be a 1-array in lowered IR
       for j in range(5):
           s += a[i, j] # accumulate row i of a into s
       b[i] = s # you can still use s at this level
   a[0] = s # you CANNOT use s here, even though it is allowed in conventional Python
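
The lowering rule for variables can be pictured in plain Python by replacing the
scalar with a one-element list. This is only an illustration of the idea;
``row_sums`` is a hypothetical function and the actual lowered IR looks different.

.. code-block:: python

   def row_sums(a):
       """Sum each row of a nested list, using a 1-element buffer as the scalar."""
       b = [0] * len(a)
       for i in range(len(a)):
           s = [0]  # first store: plays the role of the declaration of s
           for j in range(len(a[i])):
               s[0] += a[i][j]
           b[i] = s[0]  # s is still in scope at this level
       return b

   sums = row_sums([[1, 2], [3, 4]])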


Attributes
~~~~~~~~~~

So far, ONLY the ``shape`` and ``dtype`` attributes of tensors are supported!
The ``shape`` attribute is essentially a tuple, so you MUST access it as an array.
Currently, only constant-indexed access is supported.

.. code-block:: python

   x = a.shape[2] # OK!
   for i in range(3):
       for j in range(a.shape[i]): # BAD! i is not a constant!
           # do something
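
One way to stay within the constant-index rule is to read each dimension with a
literal index before looping. The sketch below is plain Python: ``FakeTensor`` is a
hypothetical stand-in that only carries a ``shape`` tuple, and ``count_elements``
is an illustrative name, not part of the hybrid frontend.

.. code-block:: python

   class FakeTensor:
       """Hypothetical stand-in for a tensor; it only exposes a shape tuple."""
       def __init__(self, shape):
           self.shape = shape

   def count_elements(a):
       # each dimension is read with a literal index, never a loop variable
       rows = a.shape[0]
       cols = a.shape[1]
       total = 0
       for i in range(rows):
           for j in range(cols):
               total += 1
       return total

   n = count_elements(FakeTensor((3, 4)))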


Conditional Statement and Expression
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

   if condition1 and condition2 and condition3:
       # do something
   else:
       # do something else
   # Select
   a = b if condition else c

However, the ``True`` and ``False`` keywords are not supported yet.


Math Intrinsics
~~~~~~~~~~~~~~~

So far, the following math intrinsics are supported: ``log``, ``exp``, ``sigmoid``,
``tanh``, ``power``, and ``popcount``.
As mentioned in `Software Emulation`_, no import is required; just use them directly.
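
For readers unfamiliar with some of these names, the sketch below gives plain
Python reference definitions built on the standard ``math`` module. These are
stand-ins written for illustration, not the functions the hybrid frontend
provides; in particular, ``popcount`` here assumes a non-negative integer.

.. code-block:: python

   import math

   def sigmoid(x):
       # logistic function 1 / (1 + e^-x)
       return 1.0 / (1.0 + math.exp(-x))

   def popcount(x):
       # number of set bits in a non-negative integer
       return bin(x).count("1")

   def power(x, y):
       return x ** y

   values = (sigmoid(0.0), popcount(0b1011), power(2, 3), math.tanh(0.0))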

Array Allocation
~~~~~~~~~~~~~~~~

**Under construction, this function will be supported later!**

Use the function call ``allocation(shape, type, share/local)`` to declare an array buffer.
The basic usage is roughly the same as that of a normal ``numpy.array``. Access a
high-dimensional array in the ``a[i, j, k]`` fashion instead of ``a[i][j][k]``,
even for a ``tvm.container.Array`` used in compilation.


Thread Bind
~~~~~~~~~~~


You can also bind a loop to a thread by writing code like this:

.. code-block:: python

   for tx in bind("threadIdx.x", 100):
       a[tx] = b[tx]
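
When reasoning about such a loop on the host, a serial stand-in is often enough.
The sketch below assumes that, for emulation purposes, ``bind`` can be treated as
a plain ``range`` over the extent that ignores the thread tag; the ``bind``
defined here is a hypothetical stand-in, not the frontend's own.

.. code-block:: python

   def bind(tag, extent):
       # hypothetical stand-in: ignore the thread tag and iterate serially
       return range(extent)

   b = list(range(100))
   a = [0] * 100
   for tx in bind("threadIdx.x", 100):
       a[tx] = b[tx]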


Assert Statement
~~~~~~~~~~~~~~~~

The assert statement is supported; you can use it exactly as in standard Python.

.. code-block:: python

    assert cond, mesg

.. note::

        ``Assert`` is NOT a function call. Users are encouraged to use assert in the way
        presented above: a condition followed by a message. This form fits both the Python AST and HalideIR.
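
A typical use is to check argument shapes before a loop runs. The example below is
plain Python with a hypothetical ``checked_add`` function; it only illustrates the
condition-followed-by-message form.

.. code-block:: python

   def checked_add(a, b):
       # condition first, message second, no parentheses around the pair
       assert len(a) == len(b), "a and b must have the same length"
       return [x + y for x, y in zip(a, b)]

   ok = checked_add([1, 2], [3, 4])

   try:
       checked_add([1], [1, 2])
       message = None
   except AssertionError as err:
       message = str(err)

Writing ``assert(cond, mesg)`` instead would assert a non-empty tuple, which is
always true in Python, so the check would never fire.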

Keywords
~~~~~~~~

- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_range``
- Math keywords: ``log``, ``exp``, ``sigmoid``, ``tanh``, ``power``, ``popcount``