tune_conv2d.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Tuning a single conv2d operator"""

from collections import namedtuple
import logging
import os

import tvm
from tvm import autotvm
from tvm.contrib.util import get_lower_ir
import topi
import vta
import vta.testing

env = vta.get_env()
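# env carries the VTA hardware parameters used below: blocking factors (BATCH,
# BLOCK_IN, BLOCK_OUT), tensor data types (inp/wgt/acc/out), and the RPC device
# key (TARGET).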

Workload = namedtuple("Conv2DWorkload",
                      ['batch', 'height', 'width', 'in_filter', 'out_filter',
                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

resnet_wkls = [
    # Workloads of resnet18 on imagenet
    # ('resnet-18.C1',  Workload(env.BATCH, 224, 224, 3,   64,  7, 7, 3, 3, 2, 2)),
    ('resnet-18.C2',  Workload(env.BATCH,  56,  56, 64,  64,  3, 3, 1, 1, 1, 1)),
    ('resnet-18.C3',  Workload(env.BATCH,  56,  56, 64,  128, 3, 3, 1, 1, 2, 2)),
    ('resnet-18.C4',  Workload(env.BATCH,  56,  56, 64,  128, 1, 1, 0, 0, 2, 2)),
    ('resnet-18.C5',  Workload(env.BATCH,  28,  28, 128, 128, 3, 3, 1, 1, 1, 1)),
    ('resnet-18.C6',  Workload(env.BATCH,  28,  28, 128, 256, 3, 3, 1, 1, 2, 2)),
    ('resnet-18.C7',  Workload(env.BATCH,  28,  28, 128, 256, 1, 1, 0, 0, 2, 2)),
    ('resnet-18.C8',  Workload(env.BATCH,  14,  14, 256, 256, 3, 3, 1, 1, 1, 1)),
    ('resnet-18.C9',  Workload(env.BATCH,  14,  14, 256, 512, 3, 3, 1, 1, 2, 2)),
    ('resnet-18.C10', Workload(env.BATCH,  14,  14, 256, 512, 1, 1, 0, 0, 2, 2)),
    ('resnet-18.C11', Workload(env.BATCH,   7,   7, 512, 512, 3, 3, 1, 1, 1, 1)),
]

@tvm.tag_scope(tag=topi.tag.ELEMWISE)
def my_clip(x, a_min, a_max):
    """Unlike topi's current clip, put min and max into two stages."""
    const_min = tvm.const(a_min, x.dtype)
    const_max = tvm.const(a_max, x.dtype)
    x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
    x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
    return x

def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation):
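    """AutoTVM template: packed conv2d + bias add + requantization for VTA.

    All tensors use VTA's blocked layout: the batch and channel dimensions are
    split into (outer, inner) pairs of size env.BATCH / env.BLOCK_IN / env.BLOCK_OUT.
    Returns the schedule and the tensor argument list.
    """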
    data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
    kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
    bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)

    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
    kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
    bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)

    with tvm.target.vta():
        res = topi.nn.conv2d(
            input=data,
            filter=kernel,
            padding=padding,
            strides=strides,
            dilation=dilation,
            layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN),
            out_dtype=env.acc_dtype)
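        # Requantize: scale the accumulator down, add the bias, clip to the
        # non-negative int8 range (acting as a fused ReLU), and cast to the
        # narrow output dtype.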
        res = topi.right_shift(res, env.WGT_WIDTH)
        res = topi.add(res, bias)
        res = my_clip(res, 0, (1 << (env.OUT_WIDTH - 1)) - 1)
        res = topi.cast(res, env.out_dtype)

    if tvm.target.current_target().device_name == 'vta':
        s = topi.generic.schedule_conv2d_nchw([res])
    else:
        s = tvm.create_schedule([res.op])

    return s, [data, kernel, bias, res]

if __name__ == '__main__':

    # Logging config (for printing tuning log to the screen)
    logging.basicConfig()
    # logging.getLogger('autotvm').setLevel(logging.DEBUG)

    # Tuning log files
    log_file = "%s.conv2d.log" % (env.TARGET)
    # create tmp log file
    tmp_log_file = log_file + ".tmp"
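    # Every measurement record is appended to the tmp file while tuning; the best
    # record per workload is distilled into log_file at the end via pick_best.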
    if os.path.exists(log_file):
        os.remove(log_file)

    # Get tracker info from env
    tracker_host = os.environ.get("TVM_TRACKER_HOST", None)
    tracker_port = os.environ.get("TVM_TRACKER_PORT", None)
    if not tracker_host or not tracker_port:
        print("Set your AutoTVM tracker node host and port variables to run the autotuner")
        exit()

    for idx, (wl_name, wl) in enumerate(resnet_wkls):
        prefix = "[Task %2d/%2d] " % (idx, len(resnet_wkls))

        # Read in workload parameters
        N = wl.batch
        CI = wl.in_filter
        H = wl.height
        W = wl.width
        CO = wl.out_filter
        KH = wl.hkernel
        KW = wl.wkernel
        strides = (wl.hstride, wl.wstride)
        padding = (wl.hpad, wl.wpad)
        dilation = (1, 1)

        # Create task
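        # Each workload becomes one AutoTVM task; the tuning knobs come from the
        # VTA conv2d schedule registered under the 'direct' template key.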
        task = autotvm.task.create(
                conv2d,
                args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation),
                target=tvm.target.vta(),
                target_host=env.target_host,
                template_key='direct')
        print(task.config_space)

        # Tune
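        # Candidate schedules are built locally and measured over RPC on the VTA
        # device registered with the tracker under the env.TARGET key.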
        measure_option = autotvm.measure_option(
                builder=autotvm.LocalBuilder(),
                runner=autotvm.RPCRunner(
                    env.TARGET, host=tracker_host, port=int(tracker_port),
                    number=5, timeout=60,
                    check_correctness=True))

        # Run Tuner
        tuner = autotvm.tuner.RandomTuner(task)
        tuner.tune(
            n_trial=len(task.config_space),
            early_stopping=None,
            measure_option=measure_option,
            callbacks=[
                    autotvm.callback.progress_bar(len(task.config_space), prefix=prefix),
                    autotvm.callback.log_to_file(tmp_log_file)])

    # Pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_file)
    os.remove(tmp_log_file)
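    # The tuned records in log_file can later be applied when compiling the same
    # workloads. A sketch (the exact build call depends on the TVM/VTA version):
    #
    #   with autotvm.apply_history_best(log_file):
    #       with tvm.target.vta():
    #           s, arg_bufs = conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation)
    #           mod = vta.build(s, arg_bufs, target=tvm.target.vta(),
    #                           target_host=env.target_host, name="conv2d")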