# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example code to perform int8 GEMM"""
import logging
import sys
import numpy as np
import tvm
from tvm import te
from tvm import autotvm
from topi.cuda.tensor_intrin import dp4a

# When True the script runs the tuner and then uses the best logged config;
# when False it skips tuning and picks the entry at PRETUNED_INDEX out of the
# task's config space instead.
DO_TUNING = True
PRETUNED_INDEX = 75333

# Tensor intrinsic used to tensorize the innermost reduction loop of the GEMM;
# all three scope arguments are 'local'.
intrin_dp4a = dp4a('local', 'local', 'local')


@autotvm.template
def gemm_int8(n, m, l):
    """Tunable schedule template for an int8 GEMM with int32 accumulation.

    Computes ``C[i, j] = sum_k int32(A[i, k]) * int32(B[j, k])``; both inputs
    are indexed with the reduction axis last, i.e. B is consumed transposed.

    Parameters
    ----------
    n : int
        Number of rows of A and of C.
    m : int
        Number of rows of B and number of columns of C.
    l : int
        Length of the reduction axis (columns of A and of B).

    Returns
    -------
    s : schedule
        The schedule created for ``C.op`` with all tunable knobs applied.
    args : list
        ``[A, B, C]``, the tensors to pass to the build step.
    """
    A = te.placeholder((n, l), name='A', dtype='int8')
    B = te.placeholder((m, l), name='B', dtype='int8')

    k = te.reduce_axis((0, l), name='k')
    C = te.compute(
        (n, m),
        lambda i, j: te.sum(A[i, k].astype('int32') * B[j, k].astype('int32'),
                            axis=k),
        name='C')

    cfg = autotvm.get_config()
    s = te.create_schedule(C.op)
    y, x = C.op.axis

    # Staging chain for the inputs: original -> 'shared' -> 'local', and a
    # 'local' accumulator (CC) that is written back to C.
    AA = s.cache_read(A, 'shared', [C])
    BB = s.cache_read(B, 'shared', [C])
    AL = s.cache_read(AA, 'local', [C])
    BL = s.cache_read(BB, 'local', [C])
    CC = s.cache_write(C, 'local')

    # After cache_write the reduction lives on CC, so re-fetch its axis.
    k = CC.op.reduce_axis[0]

    # The innermost reduction part is pinned to 4 because it is the loop that
    # gets tensorized with intrin_dp4a below.
    cfg.define_split('tile_k', cfg.axis(k), num_outputs=3,
                     filter=lambda entity: entity.size[2] == 4 and
                     entity.size[0] * 2 >= entity.size[1])

    ko, kt, ki = cfg['tile_k'].apply(s, CC, k)

    s[CC].tensorize(ki, intrin_dp4a)

    block_x = te.thread_axis('blockIdx.x')
    block_y = te.thread_axis('blockIdx.y')
    thread_x = te.thread_axis('threadIdx.x')
    thread_y = te.thread_axis('threadIdx.y')

    def block_size_filter(entity):
        """Restrict the 4-way spatial splits explored by the tuner."""
        # The original compared size[0] * 2 >= size[1] * 2; the common factor
        # is dropped here, which is equivalent for integer sizes.
        return (entity.size[0] >= entity.size[1]
                and entity.size[1] <= 16
                and entity.size[3] <= 4)

    cfg.define_split('tile_y', cfg.axis(y), num_outputs=4, filter=block_size_filter)
    cfg.define_split('tile_x', cfg.axis(x), num_outputs=4, filter=block_size_filter)
    by, tyz, ty, yi = cfg['tile_y'].apply(s, C, y)
    bx, txz, tx, xi = cfg['tile_x'].apply(s, C, x)

    # Split outputs map, outermost first, to: block, vthread, thread, serial.
    s[C].bind(by, block_y)
    s[C].bind(bx, block_x)
    s[C].bind(tyz, te.thread_axis('vthread'))
    s[C].bind(txz, te.thread_axis('vthread'))
    s[C].bind(ty, thread_y)
    s[C].bind(tx, thread_x)
    s[C].reorder(by, bx, tyz, txz, ty, tx, yi, xi)

    s[CC].compute_at(s[C], tx)

    yo, xo = CC.op.axis
    s[CC].reorder(ko, kt, yo, xo, ki)
    s[CC].unroll(kt)

    # 'local' copies of the operands: computed per kt step, vectorized in
    # groups of 4 along their second axis, and double buffered.
    for stage in [AL, BL]:
        s[stage].compute_at(s[CC], kt)
        _, vec = s[stage].split(stage.op.axis[1], factor=4)
        s[stage].vectorize(vec)
        s[stage].double_buffer()

    # 'shared' copies of the operands: computed per ko step. The fused copy
    # loop is split with the same thread extents chosen for C (index 2 of
    # tile_y / tile_x) and bound to threadIdx.y / threadIdx.x, with a
    # vectorized inner chunk of 16 elements.
    cfg.define_knob('storage_align', [16, 48])
    for stage in [AA, BB]:
        s[stage].storage_align(s[stage].op.axis[0],
                               cfg['storage_align'].val, 0)
        s[stage].compute_at(s[CC], ko)

        fused = s[stage].fuse(*s[stage].op.axis)
        load_ty, rest = s[stage].split(fused, nparts=cfg['tile_y'].size[2])
        load_tx, rest = s[stage].split(rest, nparts=cfg['tile_x'].size[2])
        _, load_vec = s[stage].split(rest, factor=16)

        s[stage].bind(load_ty, thread_y)
        s[stage].bind(load_tx, thread_x)
        s[stage].vectorize(load_vec)

    cfg.define_knob('auto_unroll_max_step', [512, 1500])
    s[C].pragma(by, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
    s[C].pragma(by, 'unroll_explicit', False)

    # One multiply and one add per (i, j, k) triple.
    cfg.add_flop(n*m*l*2)
    return s, [A, B, C]


if __name__ == '__main__':
    # Script flow: create the tuning task, tune it or pick a fixed config,
    # build the kernel, check it against NumPy, then time it.
    # Square problem: all three GEMM dimensions share the same size.
    N = 2048
    n = m = l = N

    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
    task = autotvm.task.create(gemm_int8, args=(n, m, l), target='cuda')
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(repeat=3, min_repeat_ms=100, timeout=4)
    )

    log_name = 'gemm_int8.log'
    if DO_TUNING:
        # Tune, logging every measurement to log_name, then load the best
        # record from that log as the dispatch context.
        tuner = autotvm.tuner.XGBTuner(task)
        tuner.tune(n_trial=1000, measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file(log_name)])

        dispatch_context = autotvm.apply_history_best(log_name)
        best_config = dispatch_context.query(task.target, task.workload)
        print('\nBest config:')
        print(best_config)
    else:
        # Skip tuning and use a fixed entry of the config space.
        config = task.config_space.get(PRETUNED_INDEX)
        dispatch_context = autotvm.task.ApplyConfig(config)
        print("Using pretuned config:")
        print(config)

    with dispatch_context:
        with tvm.target.create('cuda'):
            s, arg_bufs = gemm_int8(n, m, l)
            f = tvm.build(s, arg_bufs, 'cuda', name='gemm_int8')

    ctx = tvm.context('cuda', 0)

    # randint's `high` bound is exclusive, so 128 is needed for the inputs to
    # cover the whole int8 range [-128, 127].
    a_np = np.random.randint(size=(n, l), low=-128, high=128, dtype='int8')
    b_np = np.random.randint(size=(m, l), low=-128, high=128, dtype='int8')

    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((n, m), dtype='int32'), ctx)
    f(a, b, c)

    # Reference result: A x B^T computed in int32 on the host.
    tvm.testing.assert_allclose(
        c.asnumpy(),
        np.dot(
            a_np.astype('int32'),
            b_np.T.astype('int32')),
        rtol=1e-5)

    num_ops = 2 * l * m * n
    num_runs = 1000
    timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
    t = timer_f(a, b, c).mean
    # t is in seconds: ops / milliseconds / 1e6 == ops / second / 1e9.
    GOPS = num_ops / (t * 1e3) / 1e6
    print("average time cost of %d runs = %g ms, %g GOPS." %
          (num_runs, t * 1e3, GOPS))