"""VTA config tool: print build flags and hardware parameters of a VTA config."""
import os
import sys
import json
import argparse

def get_pkg_config(cfg, proj_root=None):
    """Load ``PkgConfig`` from the VTA source tree and instantiate it.

    ``<proj_root>/vta/python/vta/pkg_config.py`` is executed directly in a
    private namespace (rather than imported), and its ``PkgConfig`` class is
    constructed with ``cfg`` and the project root.

    Parameters
    ----------
    cfg : dict
        Parsed VTA configuration, passed through to ``PkgConfig``.
    proj_root : str, optional
        Project root that contains ``vta/python/vta/pkg_config.py``.
        Defaults to two directories above this script.

    Returns
    -------
    object
        The ``PkgConfig(cfg, proj_root)`` instance.
    """
    if proj_root is None:
        curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
        proj_root = os.path.join(curr_path, "../../")
    proj_root = os.path.abspath(proj_root)
    pkg_config_py = os.path.join(proj_root, "vta/python/vta/pkg_config.py")
    # Read through a context manager so the file handle is always closed.
    with open(pkg_config_py, "rb") as fin:
        source = fin.read()
    # Fresh globals for the executed code; __file__ is provided as it would
    # be for a normally imported module.
    libpkg = {"__file__": pkg_config_py}
    exec(compile(source, pkg_config_py, "exec"), libpkg, libpkg)
    return libpkg["PkgConfig"](cfg, proj_root)


def _build_parser():
    """Build the command-line parser; each query is an independent flag."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--use-cfg", type=str, default="",
                        help="path to the config json")
    parser.add_argument("--cflags", action="store_true",
                        help="print the cflags")
    parser.add_argument("--defs", action="store_true",
                        help="print the macro defs")
    parser.add_argument("--sources", action="store_true",
                        help="print the source file paths")
    # NOTE(review): --update is parsed but never read by main(); it is kept
    # so existing command lines that pass it keep working.
    parser.add_argument("--update", action="store_true",
                        help="Print out the json option.")
    parser.add_argument("--ldflags", action="store_true",
                        help="print the ldflags")
    parser.add_argument("--cfg-json", action="store_true",
                        help="print all the config json")
    parser.add_argument("--save-cfg-json", type=str, default="",
                        help="save config json to file")
    parser.add_argument("--target", action="store_true",
                        help="print the target")
    parser.add_argument("--cfg-str", action="store_true",
                        help="print the configuration string")
    parser.add_argument("--get-inpwidth", action="store_true",
                        help="returns log of input bitwidth")
    parser.add_argument("--get-wgtwidth", action="store_true",
                        help="returns log of weight bitwidth")
    parser.add_argument("--get-accwidth", action="store_true",
                        help="returns log of accum bitwidth")
    parser.add_argument("--get-outwidth", action="store_true",
                        help="returns log of output bitwidth")
    parser.add_argument("--get-batch", action="store_true",
                        help="returns log of tensor batch dimension")
    parser.add_argument("--get-blockin", action="store_true",
                        help="returns log of tensor block in dimension")
    parser.add_argument("--get-blockout", action="store_true",
                        help="returns log of tensor block out dimension")
    parser.add_argument("--get-uopbuffsize", action="store_true",
                        help="returns log of micro-op buffer size in B")
    parser.add_argument("--get-inpbuffsize", action="store_true",
                        help="returns log of input buffer size in B")
    parser.add_argument("--get-wgtbuffsize", action="store_true",
                        help="returns log of weight buffer size in B")
    parser.add_argument("--get-accbuffsize", action="store_true",
                        help="returns log of accum buffer size in B")
    parser.add_argument("--get-outbuffsize", action="store_true",
                        help="returns log of output buffer size in B")
    parser.add_argument("--get-fpgafreq", action="store_true",
                        help="returns FPGA frequency")
    parser.add_argument("--get-fpgaper", action="store_true",
                        help="returns HLS target clock period")
    return parser


def _load_config(proj_root, use_cfg=""):
    """Find, read and augment the VTA config json.

    If ``use_cfg`` is non-empty it is the only candidate path. Otherwise the
    first existing file among ``<proj_root>/vta_config.json``,
    ``<proj_root>/build/vta_config.json`` and
    ``<proj_root>/vta/config/vta_config.json`` is used.

    Returns the parsed dict with the derived ``LOG_OUT_BUFF_SIZE`` entry added.

    Raises
    ------
    RuntimeError
        If none of the candidate paths exists.
    """
    if use_cfg:
        path_list = [use_cfg]
    else:
        path_list = [
            os.path.join(proj_root, "vta_config.json"),
            os.path.join(proj_root, "build", "vta_config.json"),
            os.path.join(proj_root, "vta/config/vta_config.json"),
        ]
    ok_path_list = [p for p in path_list if os.path.exists(p)]
    if not ok_path_list:
        raise RuntimeError("Cannot find config in %s" % str(path_list))
    # Close the config file deterministically instead of leaking the handle.
    with open(ok_path_list[0]) as fin:
        cfg = json.load(fin)
    # In linear terms: out_buff = acc_buff * out_width / acc_width.
    cfg["LOG_OUT_BUFF_SIZE"] = (
        cfg["LOG_ACC_BUFF_SIZE"] +
        cfg["LOG_OUT_WIDTH"] -
        cfg["LOG_ACC_WIDTH"])
    return cfg


def main():
    """Main function: print every piece of VTA configuration requested by flags.

    With no command-line arguments, print the help text and return.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if len(sys.argv) == 1:
        parser.print_help()
        return

    curr_path = os.path.dirname(
        os.path.abspath(os.path.expanduser(__file__)))
    proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
    cfg = _load_config(proj_root, args.use_cfg)
    pkg = get_pkg_config(cfg)

    if args.target:
        print(pkg.target)

    if args.defs:
        print(" ".join(pkg.macro_defs))

    if args.sources:
        print(" ".join(pkg.lib_source))

    if args.cflags:
        cflags_str = " ".join(pkg.cflags)
        if cfg["TARGET"] == "pynq":
            cflags_str += " -DVTA_TARGET_PYNQ"
        print(cflags_str)

    if args.ldflags:
        print(" ".join(pkg.ldflags))

    if args.cfg_json:
        print(pkg.cfg_json)

    if args.save_cfg_json:
        with open(args.save_cfg_json, "w") as fo:
            fo.write(pkg.cfg_json)

    if args.cfg_str:
        # Needs to match the BITSTREAM string in python/vta/environment.py
        cfg_str = "{}x{}x{}_{}bx{}b_{}_{}_{}_{}_{}MHz_{}ns_v{}".format(
            (1 << cfg["LOG_BATCH"]),
            (1 << cfg["LOG_BLOCK_IN"]),
            (1 << cfg["LOG_BLOCK_OUT"]),
            (1 << cfg["LOG_INP_WIDTH"]),
            (1 << cfg["LOG_WGT_WIDTH"]),
            cfg["LOG_UOP_BUFF_SIZE"],
            cfg["LOG_INP_BUFF_SIZE"],
            cfg["LOG_WGT_BUFF_SIZE"],
            cfg["LOG_ACC_BUFF_SIZE"],
            cfg["HW_FREQ"],
            cfg["HW_CLK_TARGET"],
            cfg["HW_VER"].replace('.', '_'))
        print(cfg_str)

    # (parsed flag attribute, config key). Order is the print order when
    # several getters are requested at once; note out-buff precedes acc-buff.
    getters = (
        ("get_inpwidth", "LOG_INP_WIDTH"),
        ("get_wgtwidth", "LOG_WGT_WIDTH"),
        ("get_accwidth", "LOG_ACC_WIDTH"),
        ("get_outwidth", "LOG_OUT_WIDTH"),
        ("get_batch", "LOG_BATCH"),
        ("get_blockin", "LOG_BLOCK_IN"),
        ("get_blockout", "LOG_BLOCK_OUT"),
        ("get_uopbuffsize", "LOG_UOP_BUFF_SIZE"),
        ("get_inpbuffsize", "LOG_INP_BUFF_SIZE"),
        ("get_wgtbuffsize", "LOG_WGT_BUFF_SIZE"),
        ("get_outbuffsize", "LOG_OUT_BUFF_SIZE"),
        ("get_accbuffsize", "LOG_ACC_BUFF_SIZE"),
        ("get_fpgafreq", "HW_FREQ"),
        ("get_fpgaper", "HW_CLK_TARGET"),
    )
    for attr, key in getters:
        if getattr(args, attr):
            print(cfg[key])


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()