param_dict.py 1.66 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
# pylint: disable=invalid-name
"""Helper utilities to save and load parameter dicts."""
import tvm

# Handles to the globally registered parameter serializer and deserializer,
# resolved once via tvm.get_global_func when this module is imported.
_save_param_dict, _load_param_dict = (
    tvm.get_global_func("tvm.relay._save_param_dict"),
    tvm.get_global_func("tvm.relay._load_param_dict"),
)

def save_param_dict(params):
    """Serialize a parameter dictionary into binary bytes.

    The returned bytes can be handed back to a GraphModule through its
    ``load_params`` API.

    Parameters
    ----------
    params : dict of str to NDArray
        Mapping from parameter name to value. Every value is passed through
        ``tvm.nd.array`` before serialization.

    Returns
    -------
    param_bytes: bytearray
        Serialized parameters.

    Examples
    --------
    .. code-block:: python

       # compile and save the modules to file.
       graph, lib, params = tvm.relay.build(func, target=target, params=params)
       module = graph_runtime.create(graph, lib, tvm.gpu(0))
       # save the parameters as byte array
       param_bytes = tvm.relay.save_param_dict(params)
       # We can serialize the param_bytes and load it back later.
       # Pass in byte array to module to directly set parameters
       module.load_params(param_bytes)
    """
    # The serializer is called with one flat argument list that alternates
    # name, array, name, array, ... in the dict's iteration order.
    flat_args = [
        item
        for name, value in params.items()
        for item in (name, tvm.nd.array(value))
    ]
    return _save_param_dict(*flat_args)


def load_param_dict(param_bytes):
    """Load parameter dictionary from binary bytes.

    Parameters
    ----------
    param_bytes: bytes, bytearray or memoryview
        Serialized parameters. ``bytes`` and ``memoryview`` inputs are
        copied into a ``bytearray`` before being handed to the loader.

    Returns
    -------
    params : dict of str to NDArray
        The parameter dictionary, keyed by each loaded entry's ``name``.

    Raises
    ------
    TypeError
        If ``param_bytes`` is a text string (``str`` on Python 3). Text has
        no defined byte encoding here, so the caller must encode it first.
    """
    # On Python 3, ``bytearray(text)`` fails without an encoding. Reject text
    # up front with a clear message. On Python 2 ``str is bytes``, so this
    # check never fires there and byte strings fall through to the branch below.
    if isinstance(param_bytes, str) and not isinstance(param_bytes, bytes):
        raise TypeError(
            "load_param_dict expects bytes, bytearray or memoryview, got str; "
            "encode the text to bytes first")
    # Normalize other bytes-like buffers to a bytearray before calling the
    # loader. NOTE(review): presumably the FFI only converts bytearray
    # arguments; confirm against the runtime's argument marshalling.
    if isinstance(param_bytes, (bytes, memoryview)):
        param_bytes = bytearray(param_bytes)
    load_arr = _load_param_dict(param_bytes)
    return {v.name: v.array for v in load_arr}