Commit b5bcdbb0 by Haichen Shen Committed by Zhi

[Fix][VM] Fix VM invoke with set_params (#4079)

* Fix VM invoke with set_params

* add test

* tweak
parent 425430d4
......@@ -577,31 +577,33 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
auto gvit = this->global_map.find(func_name);
CHECK(gvit != this->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = this->functions[func_index];
const auto& param_names = vm_func.params;
auto ctx = this->GetParamsContext();
std::vector<Object> func_args;
// Prepare the func args
std::vector<Object> func_args(param_names.size());
std::vector<size_t> empty_slots;
for (size_t i = 0; i < param_names.size(); ++i) {
const auto& pit = params_.find(param_names[i]);
if (pit != params_.end()) {
func_args[i] = pit->second;
} else {
empty_slots.push_back(i);
}
}
CHECK_EQ(empty_slots.size(), args.size() - 1)
<< "The number of provided parameters doesn't match the number of arguments";
for (int i = 1; i < args.size(); ++i) {
Object obj = CopyTo(args[i], ctx);
func_args.push_back(obj);
}
auto it = std::find_if(functions.begin(), functions.end(),
[func_name](const VMFunction& func) {
return func.name == func_name;
});
CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n";
CHECK_EQ(func_args.size() + params_.size(), it->params.size())
<< "The number of provided parameters doesn't match the number of arguments"
<< "\n";
if (!params_.empty()) {
for (const auto& p : it->params) {
const auto& pit = params_.find(p);
if (pit != params_.end()) {
func_args.push_back(pit->second);
}
}
CHECK_EQ(func_args.size(), it->params.size());
func_args[empty_slots[i - 1]] = obj;
}
*rv = this->Invoke(func_name, func_args);
*rv = this->Invoke(vm_func, func_args);
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......
......@@ -575,6 +575,27 @@ def test_add_op_broadcast():
mod["main"] = func
check_result([x_data, y_data], x_data + y_data, mod=mod)
def test_set_params():
    """VM should merge parameters bound via load_params with the
    remaining arguments supplied at call time."""
    mod = relay.Module()
    data = relay.var('x', shape=(10, 5))
    weight = relay.var('w', shape=(6, 5))
    bias = relay.var('b', shape=(6,))
    dense_bias = relay.nn.bias_add(relay.nn.dense(data, weight), bias)
    mod["main"] = relay.Function([data, weight, bias], dense_bias)

    vm = relay.vm.VMCompiler().compile(mod, 'llvm')
    vm.init(tvm.cpu())

    x_np = np.random.uniform(size=(10, 5)).astype('float32')
    w_np = np.random.uniform(size=(6, 5)).astype('float32')
    b_np = np.random.uniform(size=(6,)).astype('float32')
    ref_np = np.dot(x_np, w_np.T) + b_np

    # Only `w` is preloaded; `x` and `b` arrive as positional run() args.
    vm.load_params({'w': w_np})
    result = vm.run(x_np, b_np)
    tvm.testing.assert_allclose(result.asnumpy(), ref_np)
if __name__ == "__main__":
test_id()
test_op()
......@@ -608,3 +629,4 @@ if __name__ == "__main__":
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
test_set_params()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment