Unverified Commit f8f75ca2 by masahi Committed by GitHub

Expose relay BindParamsByName to Python (#4751)

* expose BindParamsByName to Python

* fixed alpha equal test
parent 2c0c1849
@@ -51,6 +51,15 @@ def _update_target(target):
     return tgts


+def _convert_param_map(params):
+    inputs = {}
+    for name, param in params.items():
+        if isinstance(param, np.ndarray):
+            param = _nd.array(param)
+        inputs[name] = _expr.const(param)
+    return inputs


 class BuildModule(object):
     """Build a Relay function to run on TVM graph runtime. This class is used
     to expose the `RelayBuildModule` APIs implemented in C++.
@@ -151,12 +160,7 @@ class BuildModule(object):
     def _set_params(self, params):
-        inputs = {}
-        for name, param in params.items():
-            if isinstance(param, np.ndarray):
-                param = _nd.array(param)
-            inputs[name] = _expr.const(param)
-        self._set_params_func(inputs)
+        self._set_params_func(_convert_param_map(params))

     def get_json(self):
         """Return the json file of the built program."""
@@ -296,6 +300,29 @@ def optimize(mod, target=None, params=None):
     return mod, params


+def bind_params_by_name(func, params):
+    """Bind params to function by name.
+
+    This could be useful when assembling custom Relay optimization
+    passes that involve constant folding.
+
+    Parameters
+    ----------
+    func : relay.Function
+        The function to bind parameters to.
+
+    params : dict of str to NDArray
+        Input parameters to the graph that do not change
+        during inference time. Used for constant folding.
+
+    Returns
+    -------
+    func : relay.Function
+        The function with parameters bound.
+    """
+    inputs = _convert_param_map(params)
+    return _build_module.BindParamsByName(func, inputs)


 class GraphExecutor(_interpreter.Executor):
     """Wrapper around Executor interface.
@@ -42,6 +42,43 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
 using namespace tvm::relay::transform;

+/*!
+ * \brief Bind params to a function by name.
+ * \param func The Relay function.
+ * \param params The params dict.
+ * \return The rebound relay::Function.
+ */
+relay::Function BindParamsByName(
+    relay::Function func,
+    const std::unordered_map<std::string, runtime::NDArray>& params) {
+  std::unordered_map<std::string, relay::Var> name_dict;
+  std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
+  for (auto arg : func->params) {
+    const auto& name = arg->name_hint();
+    if (name_dict.count(name)) {
+      repeat_var.insert(arg);
+    } else {
+      name_dict[name] = arg;
+    }
+  }
+  std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
+  for (auto& kv : params) {
+    if (name_dict.count(kv.first) == 0) {
+      continue;
+    }
+    auto arg = name_dict.at(kv.first);
+    if (repeat_var.count(arg)) {
+      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
+    }
+    bind_dict[arg] = ConstantNode::make(kv.second);
+  }
+  Expr bound_expr = relay::Bind(func, bind_dict);
+  Function ret = Downcast<Function>(bound_expr);
+  CHECK(ret.defined()) << "The returning type is expected to be a Relay Function."
+                       << "\n";
+  return ret;
+}
+
 /*!
  * \brief Output of building module
  *
 */
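One behavioral detail of the function above: entries in `params` whose names match no function parameter are skipped via the `continue` branch rather than reported, while the `repeat_var` set guards binding a name that several parameters share. A small sketch of the silent-skip path (variable names invented for illustration):

```python
import numpy as np
from tvm import relay
from tvm.relay.build_module import bind_params_by_name

x = relay.var("x", shape=(4,), dtype="float32")
w = relay.var("w", shape=(4,), dtype="float32")
func = relay.Function([x, w], x + w)

# "w" is bound to a constant; "not_an_input" matches no parameter and is
# silently ignored instead of raising an error.
bound = bind_params_by_name(func, {
    "w": np.zeros(4, "float32"),
    "not_an_input": np.zeros(4, "float32"),
})
```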
@@ -249,45 +286,6 @@ class RelayBuildModule : public runtime::ModuleNode {
  protected:
-  /*!
-   * \brief Bind params to function by using name
-   * \param func Relay function
-   * \param params params dict
-   * \return relay::Function
-   */
-  relay::Function BindParamsByName(
-      relay::Function func,
-      const std::unordered_map<std::string, runtime::NDArray>& params) {
-    std::unordered_map<std::string, relay::Var> name_dict;
-    std::unordered_set<relay::Var, ObjectHash, ObjectEqual> repeat_var;
-    for (auto arg : func->params) {
-      const auto& name = arg->name_hint();
-      if (name_dict.count(name)) {
-        repeat_var.insert(arg);
-      } else {
-        name_dict[name] = arg;
-      }
-    }
-    std::unordered_map<relay::Var, Expr, ObjectHash, ObjectEqual> bind_dict;
-    for (auto& kv : params) {
-      if (name_dict.count(kv.first) == 0) {
-        continue;
-      }
-      auto arg = name_dict.at(kv.first);
-      if (repeat_var.count(arg)) {
-        LOG(FATAL) << "Multiple args in the function have name " << kv.first;
-      }
-      bind_dict[arg] = ConstantNode::make(kv.second);
-    }
-    Expr bound_expr = relay::Bind(func, bind_dict);
-    Function ret = Downcast<Function>(bound_expr);
-    CHECK(ret.defined())
-        << "The returning type is expected to be a Relay Function."
-        << "\n";
-    return ret;
-  }
-
   /*!
    * \brief Optimize a Relay Function.
    *
   * \param func The input Function where optimization will be applied on.
@@ -522,6 +520,16 @@ TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
   *rv = RelayBuildCreate();
 });

+TVM_REGISTER_GLOBAL("relay.build_module.BindParamsByName")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+  Map<std::string, Constant> params = args[1];
+  std::unordered_map<std::string, runtime::NDArray> params_;
+  for (const auto& kv : params) {
+    params_[kv.first] = kv.second->data;
+  }
+  *rv = BindParamsByName(args[0], params_);
+});
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm
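For completeness, the registered global is also reachable without the Python wrapper; the wrapper is just `_convert_param_map` plus this call. A sketch of calling the packed function directly (the dict is converted to the `Map<std::string, Constant>` the FFI body expects, which is why the values must be `relay.Constant`s here; names and shapes are illustrative):

```python
import numpy as np
import tvm
from tvm import relay

bind = tvm.get_global_func("relay.build_module.BindParamsByName")

x = relay.var("x", shape=(4,), dtype="float32")
w = relay.var("w", shape=(4,), dtype="float32")
func = relay.Function([x, w], x + w)

# Values are wrapped with relay.const, matching the Constant entries that
# the registration unpacks into NDArrays via kv.second->data.
bound = bind(func, {"w": relay.const(np.ones(4, "float32"))})
```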
@@ -18,6 +18,8 @@ import numpy as np
 import tvm
 from tvm import relay
 from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.testing import run_infer_type, create_workload


 def run_opt_pass(expr, opt_pass):
@@ -161,6 +163,47 @@ def test_fold_full():
     assert relay.analysis.graph_equal(zz, zexpected)


+def test_fold_batch_norm():
+    def expected():
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        weight = relay.const(np.zeros((16, 3, 3, 3)))
+        bias = relay.const(np.zeros((16, 1, 1)))
+        conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
+                               channels=16, padding=(1, 1))
+        add = relay.add(conv, bias)
+        return relay.Function(relay.analysis.free_vars(add), add)
+
+    remove_bn_pass = transform.Sequential([
+        relay.transform.InferType(),
+        relay.transform.SimplifyInference(),
+        relay.transform.FoldConstant(),
+        relay.transform.FoldScaleAxis(),
+    ])
+
+    data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+    weight = relay.var("weight")
+    bn_gamma = relay.var("bn_gamma")
+    bn_beta = relay.var("bn_beta")
+    bn_mmean = relay.var("bn_mean")
+    bn_mvar = relay.var("bn_var")
+
+    conv = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
+                           channels=16, padding=(1, 1))
+    bn_output = relay.nn.batch_norm(conv, bn_gamma, bn_beta,
+                                    bn_mmean, bn_mvar)
+
+    def initializer(_, param):
+        param = np.zeros(param.shape)
+
+    mod, params = create_workload(bn_output[0], initializer)
+    mod["main"] = bind_params_by_name(mod["main"], params)
+
+    with relay.build_config(opt_level=3):
+        mod = remove_bn_pass(mod)
+
+    expect = run_infer_type(expected())
+    assert relay.analysis.graph_equal(mod["main"], expect)


 if __name__ == "__main__":
     test_fold_const()
     test_fold_let()
@@ -168,3 +211,4 @@ if __name__ == "__main__":
     test_fold_concat()
     test_fold_shape_of()
     test_fold_full()
+    test_fold_batch_norm()