Commit 122a4930 by Haichen Shen Committed by Zhi

[Relay][VM] Clean up the VM and VM profiler code (#4391)

* [VM] add a few more API to vm

* [VM][Fix] fix vm convert args

* [VM] a few fixes

* rename fields

* update

* update vm profiler

* x

* add doc

* lint

* fix test

* address comments
parent 1562eaeb
......@@ -22,68 +22,24 @@ Provides extra APIs for profiling vm execution.
"""
from . import vm, _vm
def compile(mod, target=None, target_host=None, params=None):
    """Compile a Relay module with the profiler-enabled VM compiler.

    Parameters
    ----------
    mod : relay.Module
        The Relay module to build.

    target : str, :any:`tvm.target.Target`, or dict of str(i.e.
        device/context name) to str/tvm.target.Target, optional
        For heterogeneous compilation, it is a dictionary indicating context
        to target mapping. For homogeneous compilation, it is a build target.

    target_host : str or :any:`tvm.target.Target`, optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        to setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    Returns
    -------
    exec : Executable
        The executable with profiling code.
    """
    profiler_compiler = VMCompilerProfiler()

    # Normalise the targets first; the host target is derived from the
    # already-normalised device target.
    resolved_target = profiler_compiler.update_target(target)
    resolved_host = profiler_compiler.update_target_host(resolved_target, target_host)

    if params:
        profiler_compiler.set_params(params)

    # Compilation runs inside the tophub context obtained for the target.
    with profiler_compiler.tophub_context(resolved_target):
        profiler_compiler._compile(mod, resolved_target, resolved_host)

    return vm.Executable(profiler_compiler._get_exec())
def enabled():
    """Whether vm profiler is enabled.

    Returns
    -------
    available : bool
        True when the ``_vm`` module exposes a ``_VMCompilerProfiler``
        attribute, False otherwise.
    """
    profiler_compiler_present = hasattr(_vm, "_VMCompilerProfiler")
    return profiler_compiler_present
class VMCompilerProfiler(vm.VMCompiler):
    """Build Relay module to run on VM runtime."""

    def __init__(self):
        super().__init__()
        # Swap in the profiler compiler module and re-bind the packed
        # functions so they are looked up on that module.
        profiler_module = _vm._VMCompilerProfiler()
        self.mod = profiler_module
        self._compile = profiler_module["compile"]
        self._get_exec = profiler_module["get_executable"]
        self._set_params_func = profiler_module["set_params"]
return hasattr(_vm, "_VirtualMachineDebug")
class VirtualMachineProfiler(vm.VirtualMachine):
    """Relay profile VM runtime."""

    def __init__(self, mod):
        # Initialise the base class exactly once; calling it twice would
        # only construct and discard a second base VM module.
        super(VirtualMachineProfiler, self).__init__(mod)
        m = mod.module if isinstance(mod, vm.Executable) else mod
        # Replace the module created by the base class with the debug VM and
        # re-bind every packed function so calls are routed to it.
        self.mod = _vm._VirtualMachineDebug(m)
        self._init = self.mod["init"]
        self._invoke = self.mod["invoke"]
        self._get_stat = self.mod["get_stat"]
        self._set_input = self.mod["set_input"]
        self._reset = self.mod["reset"]

    def get_stat(self):
        """Return whatever the debug VM's ``get_stat`` packed function reports."""
        return self._get_stat()

    def reset(self):
        """Invoke the debug VM's ``reset`` packed function."""
        self._reset()
......@@ -34,7 +34,9 @@ Tensor = _obj.Tensor
ADT = _obj.ADT
def _convert(arg, cargs):
    """Convert one argument to a VM object and append it to ``cargs``.

    Parameters
    ----------
    arg : _obj.Object, np.ndarray, tvm.nd.NDArray, tuple or list
        The value to convert. Tuples and lists are converted field by field
        and wrapped with ``_obj.tuple_object``.

    cargs : list
        Output list; the converted object is appended in place.

    Raises
    ------
    TypeError
        If ``arg`` is none of the supported types.
    """
    if isinstance(arg, _obj.Object):
        # Already a VM object: pass it through untouched.
        cargs.append(arg)
    elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
        cargs.append(_obj.Tensor(arg))
    elif isinstance(arg, (tuple, list)):
        field_args = []
        for field in arg:
            _convert(field, field_args)
        cargs.append(_obj.tuple_object(field_args))
    else:
        # Raising a bare string is itself a TypeError in Python 3 and loses
        # the message; raise a real exception that carries it.
        raise TypeError("Unsupported type: %s" % (type(arg)))
def convert(args):
......@@ -57,10 +59,13 @@ class Executable(object):
"""Relay VM executable"""
def __init__(self, mod):
    """Wrap an executable module and bind the packed functions it exposes.

    Parameters
    ----------
    mod : module-like
        Object indexable by packed-function name; stored as ``self.mod``.
    """
    self.mod = mod
    # Per-function cache of parameter names, filled lazily.
    self._function_params = {}
    packed = self.mod
    self._save = packed["save"]
    self._get_lib = packed["get_lib"]
    self._get_bytecode = packed["get_bytecode"]
    self._get_stats = packed["get_stats"]
    self._get_function_arity = packed["get_function_arity"]
    self._get_function_param_name = packed["get_function_param_name"]
def save(self):
"""Save the Relay VM Executable.
......@@ -239,6 +244,20 @@ class Executable(object):
"""Return the runtime module contained in a virtual machine executable."""
return self.mod
def get_function_params(self, func_name):
    """Get VM Function parameters.

    Parameters
    ----------
    func_name : str
        The name of the function.

    Returns
    -------
    params : list
        Parameter names in positional order. The list is cached per
        function name, so later calls return the same list object.
    """
    cache = self._function_params
    try:
        return cache[func_name]
    except KeyError:
        pass

    num_params = self._get_function_arity(func_name)
    assert num_params >= 0

    names = []
    for position in range(num_params):
        name = self._get_function_param_name(func_name, position)
        assert name
        names.append(name)

    cache[func_name] = names
    return names
class VirtualMachine(object):
"""Relay VM runtime."""
......@@ -248,8 +267,10 @@ class VirtualMachine(object):
"tvm.Module, but received {}".format(type(mod)))
m = mod.module if isinstance(mod, Executable) else mod
self.mod = _vm._VirtualMachine(m)
self._exec = mod
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._set_input = self.mod["set_input"]
def init(self, ctx):
"""Initialize the context in the VM.
......@@ -262,7 +283,37 @@ class VirtualMachine(object):
args = [ctx.device_type, ctx.device_id]
self._init(*args)
def invoke(self, func_name, *args):
def set_input(self, func_name, *args, **kwargs):
    """Set the input to a function.

    Parameters
    ----------
    func_name : str
        The name of the function.

    args : list[NDArray] or list[np.ndarray]
        The arguments to the function.

    kwargs: dict of str to NDArray or np.ndarray
        Named arguments to the function.
    """
    if kwargs:
        param_names = self._exec.get_function_params(func_name)
        slots = [None] * len(param_names)
        assert len(args) + len(kwargs) == len(param_names)

        # Place every named argument at its declared position.
        for name, value in kwargs.items():
            slots[param_names.index(name)] = value

        # Fill the remaining (still None) positions with the positional
        # arguments, in order.
        cursor = 0
        for position in range(len(slots)):
            if slots[position] is not None:
                continue
            slots[position] = args[cursor]
            cursor += 1
        args = slots

    converted = convert(args)
    self._set_input(func_name, *converted)
def invoke(self, func_name, *args, **kwargs):
"""Invoke a function.
Parameters
......@@ -273,15 +324,19 @@ class VirtualMachine(object):
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
Returns
-------
result : Object
The output.
"""
cargs = convert(args)
return self._invoke(func_name, *cargs)
if args or kwargs:
self.set_input(func_name, *args, **kwargs)
return self._invoke(func_name)
def run(self, *args):
def run(self, *args, **kwargs):
"""Run the main function.
Parameters
......@@ -289,12 +344,15 @@ class VirtualMachine(object):
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
Returns
-------
result : Object
The output.
"""
return self.invoke("main", *args)
return self.invoke("main", *args, **kwargs)
def compile(mod, target=None, target_host=None, params=None):
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/backend/vm/profiler/compiler.cc
* \brief A compiler from relay::Module to the VM byte code.
*/
#include "../../../../runtime/vm/profiler/vm.h"
#include "../compiler.h"
namespace tvm {
namespace relay {
namespace vm {
// VMCompiler subclass declared in the profiler compiler source. It adds no
// members or overrides of its own here; only an empty constructor and an
// empty virtual destructor are defined.
class VMCompilerDebug : public VMCompiler {
 public:
  VMCompilerDebug() {}
  virtual ~VMCompilerDebug() {}
};
// Factory: allocate a VMCompilerDebug and hand it back wrapped in a
// runtime::Module.
runtime::Module CreateVMCompilerDebug() {
  return runtime::Module(make_object<VMCompilerDebug>());
}
// Register the factory under the global name "relay._vm._VMCompilerProfiler".
// The call arguments are not read; each call returns a newly created
// VMCompilerDebug module through rv.
TVM_REGISTER_GLOBAL("relay._vm._VMCompilerProfiler")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = CreateVMCompilerDebug();
});
} // namespace vm
} // namespace relay
} // namespace tvm
......@@ -30,6 +30,7 @@
#include <algorithm>
#include <memory>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <utility>
#include <vector>
......@@ -67,44 +68,76 @@ PackedFunc Executable::GetFunction(const std::string& name,
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Save();
});
} else if (name == "get_function_arity") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
*rv = this->GetFunctionArity(func_name);
});
} else if (name == "get_function_param_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
} else {
LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr);
}
}
/*!
 * \brief Number of parameters of a VM function.
 * \param func_name Name of the function, as keyed in global_map.
 * \return The parameter count, or -1 (after logging an error) when the
 *         function is not present in the executable.
 */
int Executable::GetFunctionArity(std::string func_name) const {
  const auto entry = global_map.find(func_name);
  if (entry != global_map.end()) {
    return functions[entry->second].params.size();
  }
  LOG(ERROR) << "Cannot find function " << func_name << " in executable";
  return -1;
}
/*!
 * \brief Look up the name of one parameter of a VM function.
 * \param func_name Name of the function, as keyed in global_map.
 * \param index Zero-based position of the parameter.
 * \return The parameter name, or an empty string (after logging an error)
 *         when the function is unknown or index is out of range.
 */
std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
  auto it = global_map.find(func_name);
  if (it == global_map.end()) {
    LOG(ERROR) << "Cannot find function " << func_name << " in executable";
    return "";
  }
  const auto& func = functions[it->second];
  // Valid indices are [0, size). `index == size` must be rejected too,
  // otherwise params[index] reads one element past the end.
  if (index >= func.params.size()) {
    LOG(ERROR) << "Invalid parameter index";
    return "";
  }
  return func.params[index];
}
std::string Executable::GetBytecode() const {
std::ostringstream oss;
for (const auto& func : functions) {
for (size_t i = 0; i < functions.size(); ++i) {
const auto& func = functions[i];
// Print the header of the function format.
oss << "# func name, reg file size, param count, inst count:"
<< std::endl;
oss << func.name << " "
<< func.register_file_size << " "
<< func.params.size() << " "
<< func.instructions.size() << std::endl;
// Print params of a `VMFunction`.
oss << "# Parameters: "<< std::endl;
oss << "VM Function[" << i << "]: " << func.name << "(";
for (const auto& param : func.params) {
oss << param << " ";
oss << param << ", ";
}
oss << std::endl;
oss.seekp(-2, std::ios_base::end);
oss << ")" << std::endl;
oss << "# reg file size = " << func.register_file_size << std::endl;
oss << "# instruction count = " << func.instructions.size() << std::endl;
// Print the instructions of a `VMFunction`.
// The part after ";" is the instruction in text format.
oss << "hash, opcode, fields # inst(text):"<< std::endl;
for (const auto& instr : func.instructions) {
oss << "opcode, fields # inst(text):" << std::endl;
for (size_t idx = 0; idx < func.instructions.size(); ++idx) {
const auto& instr = func.instructions[idx];
const auto& serialized_instr = SerializeInstruction(instr);
oss << std::hex << "0x" << serialized_instr.Hash() << " "
<< std::dec << serialized_instr.opcode << " ";
oss << std::setw(2) << idx << ": " << serialized_instr.opcode << " ";
for (auto it : serialized_instr.fields) {
oss << it << " ";
}
oss << " # " << instr;
if (oss.str().back() != '\n') oss << std::endl;
}
oss << std::endl;
}
return oss.str();
......
......@@ -50,15 +50,15 @@ PackedFunc VirtualMachineDebug::GetFunction(
<< "\t"
<< "#Duration(us): Sum/Mean/Min/Max" << std::endl;
for (auto kv : op_durations) {
auto vals = op_durations[kv.first];
for (auto kv : op_durations_) {
auto vals = op_durations_[kv.first];
auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);;
auto mean = sum / static_cast<double>(vals.size());
auto min_value = *std::min_element(vals.begin(), vals.end());
auto max_value = *std::max_element(vals.begin(), vals.end());
os << std::setw(30) << std::left << packed_index_map[kv.first] << "\t"
<< std::setw(10) << std::left << op_invokes[kv.first] << "\t"
os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t"
<< std::setw(10) << std::left << op_invokes_[kv.first] << "\t"
<< sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;
total_duration += sum;
......@@ -66,18 +66,10 @@ PackedFunc VirtualMachineDebug::GetFunction(
os << "Total Duration " << total_duration << " us" << std::endl;
*rv = os.str();
});
} else if (name == "init") {
} else if (name == "reset") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size() % 2, 0);
std::vector<TVMContext> contexts;
for (int i = 0; i < args.size() / 2; ++i) {
TVMContext ctx;
int device_type = args[i * 2];
ctx.device_type = DLDeviceType(device_type);
ctx.device_id = args[i * 2 + 1];
contexts.push_back(ctx);
}
this->Init(contexts);
op_durations_.clear();
op_invokes_.clear();
});
} else {
return VirtualMachine::GetFunction(name, sptr_to_self);
......@@ -86,31 +78,25 @@ PackedFunc VirtualMachineDebug::GetFunction(
// Load the executable through the base class, then record for every entry of
// the executable's primitive_map a reverse mapping from packed-function index
// to name, and start that index's invocation counter at zero.
void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
  VirtualMachine::LoadExecutable(exec);
  // Only the renamed (trailing-underscore) members are used; the stale
  // pre-rename lines left an unclosed loop behind.
  CHECK(exec_);
  for (auto kv : exec_->primitive_map) {
    packed_index_map_[kv.second] = kv.first;
    op_invokes_[kv.second] = 0;
  }
}
// Forward the context list unchanged to VirtualMachine::Init; this override
// adds no behavior of its own.
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
  VirtualMachine::Init(ctxs);
}
void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args) {
CHECK(this->exec);
CHECK(exec_);
auto ctx = this->GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_begin = std::chrono::high_resolution_clock::now();
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_end = std::chrono::high_resolution_clock::now();
double op_duration =
......@@ -118,8 +104,8 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
op_begin)
.count();
op_durations[packed_index].push_back(op_duration * 1e6);
op_invokes[packed_index] += 1;
op_durations_[packed_index].push_back(op_duration * 1e6);
op_invokes_[packed_index] += 1;
}
runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
......
......@@ -43,19 +43,17 @@ class VirtualMachineDebug : public VirtualMachine {
PackedFunc GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final;
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final;
void LoadExecutable(const Executable* exec);
void LoadExecutable(const Executable* exec) final;
~VirtualMachineDebug() {}
private:
void Init(const std::vector<TVMContext>& ctxs);
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final;
std::unordered_map<Index, std::string> packed_index_map;
std::unordered_map<Index, std::vector<double>> op_durations;
std::unordered_map<Index, int> op_invokes;
std::unordered_map<Index, std::string> packed_index_map_;
std::unordered_map<Index, std::vector<double>> op_durations_;
std::unordered_map<Index, int> op_invokes_;
};
} // namespace vm
......
......@@ -47,18 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod["main"] = f
exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
return vm.invoke("main", *args)
else:
assert isinstance(f, relay.Module), "expected expression or module"
mod = f
exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
ret = vm.invoke("main", *args)
return ret
exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
return vm.invoke("main", *args)
def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vm.Tensor):
......@@ -577,35 +572,4 @@ def test_add_op_broadcast():
if __name__ == "__main__":
test_id()
test_op()
test_cond()
test_simple_if()
test_simple_call()
test_count_loop()
test_sum_loop()
test_tuple_fst()
test_tuple_second()
test_let_scalar()
test_let_tensor()
test_split()
test_split_no_fuse()
test_list_constructor()
test_let_tensor()
test_let_scalar()
test_compose()
test_list_hd()
test_list_tl_empty_list()
test_list_tl()
test_list_nth()
test_list_update()
test_list_length()
test_list_map()
test_list_foldl()
test_list_foldr()
test_list_sum()
test_list_filter()
test_closure()
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
pytest.main()
......@@ -107,9 +107,9 @@ def test_serializer():
assert any(item.startswith('fused_multiply') for item in prim_ops)
code = exe.bytecode
assert "main 8 2 8" in code
assert "f1 5 1 6" in code
assert "f2 5 1 6" in code
assert "main(x1, y1)" in code
assert "f1(x)" in code
assert "f2(y)" in code
code, lib = exe.save()
assert isinstance(code, bytearray)
......
......@@ -28,7 +28,7 @@ def test_basic():
ctx = tvm.cpu()
if not relay.profiler_vm.enabled():
return
exe = relay.profiler_vm.compile(mod, target, params=params)
exe = relay.vm.compile(mod, target, params=params)
vm = relay.profiler_vm.VirtualMachineProfiler(exe)
vm.init(ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment