Commit 122a4930 by Haichen Shen Committed by Zhi

[Relay][VM] Clean up the VM and VM profiler code (#4391)

* [VM] add a few more API to vm

* [VM][Fix] fix vm convert args

* [VM] a few fixes

* rename fields

* update

* update vm profiler

* x

* add doc

* lint

* fix test

* address comments
parent 1562eaeb
...@@ -268,125 +268,142 @@ struct Instruction { ...@@ -268,125 +268,142 @@ struct Instruction {
} alloc_storage; } alloc_storage;
}; };
/*! \brief Construct a return instruction. /*!
* \param return_reg The register containing the return value. * \brief Construct a return instruction.
* \return The return instruction. * \param return_reg The register containing the return value.
* */ * \return The return instruction.
*/
static Instruction Ret(RegName return_reg); static Instruction Ret(RegName return_reg);
/*! \brief Construct a fatal instruction. /*!
* \return The fatal instruction. * \brief Construct a fatal instruction.
* */ * \return The fatal instruction.
*/
static Instruction Fatal(); static Instruction Fatal();
/*! \brief Construct a invoke packed instruction. /*!
* \param packed_index The index of the packed function. * \brief Construct a invoke packed instruction.
* \param arity The arity of the function. * \param packed_index The index of the packed function.
* \param output_size The number of outputs of the packed function. * \param arity The arity of the function.
* \param args The argument registers. * \param output_size The number of outputs of the packed function.
* \return The invoke packed instruction. * \param args The argument registers.
* \return The invoke packed instruction.
*/ */
static Instruction InvokePacked(Index packed_index, Index arity, Index output_size, static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args); const std::vector<RegName>& args);
/*! \brief Construct an allocate tensor instruction with constant shape. /*!
* \param storage The storage to allocate out of. * \brief Construct an allocate tensor instruction with constant shape.
* \param shape The shape of the tensor. * \param storage The storage to allocate out of.
* \param dtype The dtype of the tensor. * \param shape The shape of the tensor.
* \param dst The destination register. * \param dtype The dtype of the tensor.
* \return The allocate tensor instruction. * \param dst The destination register.
* \return The allocate tensor instruction.
*/ */
static Instruction AllocTensor(RegName storage, static Instruction AllocTensor(RegName storage,
const std::vector<int64_t>& shape, DLDataType dtype, RegName dst); const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate tensor instruction with register. /*!
* \param storage The storage to allocate out of. * \brief Construct an allocate tensor instruction with register.
* \param shape_register The register containing the shape. * \param storage The storage to allocate out of.
* \param dtype The dtype of the tensor. * \param shape_register The register containing the shape.
* \param dst The destination register. * \param dtype The dtype of the tensor.
* \return The allocate tensor instruction. * \param dst The destination register.
* \return The allocate tensor instruction.
*/ */
static Instruction AllocTensorReg(RegName storage, static Instruction AllocTensorReg(RegName storage,
RegName shape_register, DLDataType dtype, RegName dst); RegName shape_register, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate datatype instruction. /*!
* \param tag The datatype tag. * \brief Construct an allocate datatype instruction.
* \param num_fields The number of fields for the datatype. * \param tag The datatype tag.
* \param fields The registers containing the fields. * \param num_fields The number of fields for the datatype.
* \param dst The register name of the destination. * \param fields The registers containing the fields.
* \return The allocate instruction tensor. * \param dst The register name of the destination.
* \return The allocate instruction tensor.
*/ */
static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields, static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
RegName dst); RegName dst);
/*! \brief Construct an allocate closure instruction. /*!
* \param func_index The index of the function table. * \brief Construct an allocate closure instruction.
* \param num_freevar The number of free variables. * \param func_index The index of the function table.
* \param free_vars The registers of the free variables. * \param num_freevar The number of free variables.
* \param dst The destination register. * \param free_vars The registers of the free variables.
* \return The allocate closure instruction. * \param dst The destination register.
* \return The allocate closure instruction.
*/ */
static Instruction AllocClosure(Index func_index, Index num_freevar, static Instruction AllocClosure(Index func_index, Index num_freevar,
const std::vector<RegName>& free_vars, RegName dst); const std::vector<RegName>& free_vars, RegName dst);
/*! \brief Construct a get field instruction. /*!
* \param object_reg The register containing the object to project from. * \brief Construct a get field instruction.
* \param field_index The field to read out of the object. * \param object_reg The register containing the object to project from.
* \param dst The destination register. * \param field_index The field to read out of the object.
* \return The get field instruction. * \param dst The destination register.
* \return The get field instruction.
*/ */
static Instruction GetField(RegName object_reg, Index field_index, RegName dst); static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
/*! \brief Construct a get_tag instruction. /*!
* \param object_reg The register containing the object to project from. * \brief Construct a get_tag instruction.
* \param dst The destination register. * \param object_reg The register containing the object to project from.
* \return The get_tag instruction. * \param dst The destination register.
* \return The get_tag instruction.
*/ */
static Instruction GetTag(RegName object_reg, RegName dst); static Instruction GetTag(RegName object_reg, RegName dst);
/*! \brief Construct an if instruction. /*!
* \param test The register containing the test value. * \brief Construct an if instruction.
* \param target The register containing the target value. * \param test The register containing the test value.
* \param true_branch The offset to the true branch. * \param target The register containing the target value.
* \param false_branch The offset to the false branch. * \param true_branch The offset to the true branch.
* \return The if instruction. * \param false_branch The offset to the false branch.
* \return The if instruction.
*/ */
static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch); static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch);
/*! \brief Construct a goto instruction. /*!
* \param pc_offset The offset from the current pc. * \brief Construct a goto instruction.
* \return The goto instruction. * \param pc_offset The offset from the current pc.
* \return The goto instruction.
*/ */
static Instruction Goto(Index pc_offset); static Instruction Goto(Index pc_offset);
/*! \brief Construct an invoke instruction. /*!
* \param func_index The index of the function to invoke. * \brief Construct an invoke instruction.
* \param args The registers containing the arguments. * \param func_index The index of the function to invoke.
* \param dst The destination register. * \param args The registers containing the arguments.
* \return The invoke instruction. * \param dst The destination register.
* \return The invoke instruction.
*/ */
static Instruction Invoke(Index func_index, const std::vector<RegName>& args, RegName dst); static Instruction Invoke(Index func_index, const std::vector<RegName>& args, RegName dst);
/*! \brief Construct an invoke closure instruction. /*!
* \param closure The register of the closure to invoke. * \brief Construct an invoke closure instruction.
* \param args The registers containing the arguments. * \param closure The register of the closure to invoke.
* \param dst The destination register. * \param args The registers containing the arguments.
* \return The invoke closure instruction. * \param dst The destination register.
* \return The invoke closure instruction.
*/ */
static Instruction InvokeClosure(RegName closure, const std::vector<RegName>& args, RegName dst); static Instruction InvokeClosure(RegName closure, const std::vector<RegName>& args, RegName dst);
/*! \brief Construct a load constant instruction. /*!
* \param const_index The index of the constant. * \brief Construct a load constant instruction.
* \param dst The destination register. * \param const_index The index of the constant.
* \return The load constant instruction. * \param dst The destination register.
* \return The load constant instruction.
*/ */
static Instruction LoadConst(Index const_index, RegName dst); static Instruction LoadConst(Index const_index, RegName dst);
/*! \brief Construct a load_constanti instruction. /*!
* \param val The interger constant value. * \brief Construct a load_constanti instruction.
* \param dst The destination register. * \param val The interger constant value.
* \return The load_constanti instruction. * \param dst The destination register.
* \return The load_constanti instruction.
*/ */
static Instruction LoadConsti(Index val, RegName dst); static Instruction LoadConsti(Index val, RegName dst);
/*! \brief Construct a move instruction. /*!
* \param src The source register. * \brief Construct a move instruction.
* \param dst The destination register. * \param src The source register.
* \return The move instruction. * \param dst The destination register.
* \return The move instruction.
*/ */
static Instruction Move(RegName src, RegName dst); static Instruction Move(RegName src, RegName dst);
/*! \brief Allocate a storage block. /*!
* \param size The size of the allocation. * \brief Allocate a storage block.
* \param alignment The allocation's alignment. * \param size The size of the allocation.
* \param dtype_hint The data type hint for the allocator. * \param alignment The allocation's alignment.
* \param dst The destination to place the storage. * \param dtype_hint The data type hint for the allocator.
* \return The alloc storage instruction. * \param dst The destination to place the storage.
* \return The alloc storage instruction.
*/ */
static Instruction AllocStorage(RegName size, RegName alignment, static Instruction AllocStorage(RegName size, RegName alignment,
DLDataType dtype_hint, RegName dst); DLDataType dtype_hint, RegName dst);
...@@ -399,7 +416,8 @@ struct Instruction { ...@@ -399,7 +416,8 @@ struct Instruction {
friend std::ostream& operator<<(std::ostream& os, const Instruction&); friend std::ostream& operator<<(std::ostream& os, const Instruction&);
}; };
/*! \brief A representation of a Relay function in the VM. /*!
* \brief A representation of a Relay function in the VM.
* *
* Contains metadata about the compiled function, as * Contains metadata about the compiled function, as
* well as the compiled VM instructions. * well as the compiled VM instructions.
...@@ -427,7 +445,8 @@ struct VMFunction { ...@@ -427,7 +445,8 @@ struct VMFunction {
friend std::ostream& operator<<(std::ostream& os, const VMFunction&); friend std::ostream& operator<<(std::ostream& os, const VMFunction&);
}; };
/*! \brief A representation of a stack frame. /*!
* \brief A representation of a stack frame.
* *
* A stack frame is a record containing the information needed * A stack frame is a record containing the information needed
* to restore the caller's virtual machine state after returning * to restore the caller's virtual machine state after returning
...@@ -458,7 +477,8 @@ struct VMFrame { ...@@ -458,7 +477,8 @@ struct VMFrame {
caller_return_register(0) {} caller_return_register(0) {}
}; };
/*! \brief The executable emitted by the VM compiler. /*!
* \brief The executable emitted by the VM compiler.
* *
* The executable contains information (e.g. data in different memory regions) * The executable contains information (e.g. data in different memory regions)
* to run in a virtual machine. * to run in a virtual machine.
...@@ -534,19 +554,35 @@ class Executable : public ModuleNode { ...@@ -534,19 +554,35 @@ class Executable : public ModuleNode {
*/ */
std::string GetBytecode() const; std::string GetBytecode() const;
/*! /*!
* \brief Print the detailed statistics of the given code, i.e. number of * \brief Print the detailed statistics of the given code, i.e. number of
* globls and constants, etc. * globls and constants, etc.
*/ */
std::string Stats() const; std::string Stats() const;
/*! \brief Get the `lib` module in an executable. Users have the flexibility to call /*!
* \brief Get the `lib` module in an executable. Users have the flexibility to call
* `export_library` from the frontend to save the library to disk. * `export_library` from the frontend to save the library to disk.
* *
* \return The runtime module that contains the hardwre dependent code. * \return The runtime module that contains the hardwre dependent code.
*/ */
runtime::Module GetLib() const { return lib; } runtime::Module GetLib() const { return lib; }
/*!
* \brief Get the arity of the VM Fucntion.
* \param func Function name.
* \return The number of parameters.
*/
int GetFunctionArity(std::string func) const;
/*!
* \brief Get the parameter name given the function name and parameter index.
* \param func Function name.
* \param index Parameter index.
* \return The parameter name.
*/
std::string GetFunctionParameterName(std::string func, uint32_t index) const;
virtual ~Executable() {} virtual ~Executable() {}
const char* type_key() const final { const char* type_key() const final {
...@@ -628,7 +664,8 @@ class Executable : public ModuleNode { ...@@ -628,7 +664,8 @@ class Executable : public ModuleNode {
std::string code_; std::string code_;
}; };
/*! \brief The virtual machine. /*!
* \brief The virtual machine.
* *
* The virtual machine contains all the current execution state, * The virtual machine contains all the current execution state,
* as well as the executable. * as well as the executable.
...@@ -660,83 +697,72 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -660,83 +697,72 @@ class VirtualMachine : public runtime::ModuleNode {
virtual PackedFunc GetFunction(const std::string& name, virtual PackedFunc GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self); const ObjectPtr<Object>& sptr_to_self);
/*!
* \brief Invoke a PackedFunction
*
* \param packed_index The offset of the PackedFunction in all functions.
* \param func The PackedFunction to be invoked.
* \param arg_count The number of arguments to the PackedFunction.
* \param output_size The number of outputs of the PackedFunction.
* \param args Arguments to the PackedFunction.
*
* \note The return value will be stored in the last output_size slots of args.
*/
virtual void InvokePacked(Index packed_index,
const PackedFunc& func,
Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args);
virtual ~VirtualMachine() {} virtual ~VirtualMachine() {}
const char* type_key() const final { const char* type_key() const final {
return "VirtualMachine"; return "VirtualMachine";
} }
VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {} VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {}
/*! \brief load the executable for the virtual machine. /*!
* \param exec The executable. * \brief load the executable for the virtual machine.
* \param exec The executable.
*/ */
void LoadExecutable(const Executable* exec); virtual void LoadExecutable(const Executable* exec);
protected: protected:
/*! \brief The virtual machine's packed function table. */ /*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs; std::vector<PackedFunc> packed_funcs_;
/*! \brief The current stack of call frames. */ /*! \brief The current stack of call frames. */
std::vector<VMFrame> frames; std::vector<VMFrame> frames_;
/*! \brief The fuction table index of the current function. */ /*! \brief The fuction table index of the current function. */
Index func_index; Index func_index_;
/*! \brief The current pointer to the code section. */ /*! \brief The current pointer to the code section. */
const Instruction* code; const Instruction* code_;
/*! \brief The virtual machine PC. */ /*! \brief The virtual machine PC. */
Index pc; Index pc_;
/*! \brief The special return register. */ /*! \brief The special return register. */
ObjectRef return_register; ObjectRef return_register_;
/*! \brief The executable the VM will operate on. */ /*! \brief The executable the VM will operate on. */
const Executable* exec; const Executable* exec_;
/*! \brief The function name to inputs mapping. */
std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
/*! \brief The set of TVM contexts the VM is currently executing on. */ /*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs; std::vector<TVMContext> ctxs_;
/*! \brief Push a call frame on to the call stack. */ /*! \brief Push a call frame on to the call stack. */
void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func); void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
/*! \brief Pop a frame off the call stack.
* \return The number of frames left. /*!
* \brief Pop a frame off the call stack.
* \return The number of frames left.
*/ */
Index PopFrame(); Index PopFrame();
/*! \brief Write to a VM register. /*!
* \param reg The register to write to. * \brief Write to a VM register.
* \param obj The object to write to. * \param reg The register to write to.
* \param obj The object to write to.
*/ */
inline void WriteRegister(RegName reg, const ObjectRef& obj); inline void WriteRegister(RegName reg, const ObjectRef& obj);
/*! \brief Read a VM register. /*!
* \param reg The register to read from. * \brief Read a VM register.
* \return The read object. * \param reg The register to read from.
* \return The read object.
*/ */
inline ObjectRef ReadRegister(RegName reg) const; inline ObjectRef ReadRegister(RegName reg) const;
/*! \brief Read a VM register and cast it to int32_t /*!
* \param reg The register to read from. * \brief Read a VM register and cast it to int32_t
* \return The read scalar. * \param reg The register to read from.
* \return The read scalar.
*/ */
int32_t LoadScalarInt(RegName reg) const; int32_t LoadScalarInt(RegName reg) const;
/*! \brief Invoke a VM function. /*!
* \brief Invoke a VM function.
* \param func The function. * \param func The function.
* \param args The arguments to the function. * \param args The arguments to the function.
* \return The object representing the result. * \return The object representing the result.
...@@ -752,29 +778,43 @@ class VirtualMachine : public runtime::ModuleNode { ...@@ -752,29 +778,43 @@ class VirtualMachine : public runtime::ModuleNode {
*/ */
ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args); ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
/*! \brief Initialize the virtual machine for a set of contexts. /*!
* \param contexts The set of TVM contexts. * \brief Invoke a PackedFunction
*
* \param packed_index The offset of the PackedFunction in all functions.
* \param func The PackedFunction to be invoked.
* \param arg_count The number of arguments to the PackedFunction.
* \param output_size The number of outputs of the PackedFunction.
* \param args Arguments to the PackedFunction.
*
* \note The return value will be stored in the last output_size slots of args.
*/
virtual void InvokePacked(Index packed_index,
const PackedFunc& func,
Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args);
/*!
* \brief Initialize the virtual machine for a set of contexts.
* \param contexts The set of TVM contexts.
*/ */
void Init(const std::vector<TVMContext>& contexts); void Init(const std::vector<TVMContext>& contexts);
/*! \brief Run VM dispatch loop. /*! \brief Run VM dispatch loop. */
*/
void RunLoop(); void RunLoop();
/*! \brief Get device context for params. /*! \brief Get device context for params. */
*/
TVMContext GetParamsContext() const; TVMContext GetParamsContext() const;
private: private:
/*! \brief Invoke a global setting up the VM state to execute. /*!
* \brief Invoke a global setting up the VM state to execute.
* *
* This does not begin execution of the VM. * This does not begin execution of the VM.
*/ */
void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args); void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, ObjectRef> params_;
/*! /*!
* \brief The constant pool for runtime. It caches the device dependent * \brief The constant pool for runtime. It caches the device dependent
* object to avoid rellocation of constants during inference. * object to avoid rellocation of constants during inference.
......
...@@ -22,68 +22,24 @@ Provides extra APIs for profiling vm execution. ...@@ -22,68 +22,24 @@ Provides extra APIs for profiling vm execution.
""" """
from . import vm, _vm from . import vm, _vm
def compile(mod, target=None, target_host=None, params=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
exec : Executable
The executable with profiling code.
"""
compiler = VMCompilerProfiler()
target = compiler.update_target(target)
target_host = compiler.update_target_host(target, target_host)
if params:
compiler.set_params(params)
tophub_context = compiler.tophub_context(target)
with tophub_context:
compiler._compile(mod, target, target_host)
return vm.Executable(compiler._get_exec())
def enabled(): def enabled():
"""Whether vm profiler is enabled.""" """Whether vm profiler is enabled."""
return hasattr(_vm, "_VMCompilerProfiler") return hasattr(_vm, "_VirtualMachineDebug")
class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
def __init__(self):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_exec = self.mod["get_executable"]
self._set_params_func = self.mod["set_params"]
class VirtualMachineProfiler(vm.VirtualMachine): class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime.""" """Relay profile VM runtime."""
def __init__(self, mod): def __init__(self, mod):
super().__init__(mod) super(VirtualMachineProfiler, self).__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m) self.mod = _vm._VirtualMachineDebug(m)
self._init = self.mod["init"] self._init = self.mod["init"]
self._invoke = self.mod["invoke"] self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"] self._get_stat = self.mod["get_stat"]
self._set_input = self.mod["set_input"]
self._reset = self.mod["reset"]
def get_stat(self): def get_stat(self):
return self._get_stat() return self._get_stat()
def reset(self):
self._reset()
...@@ -34,7 +34,9 @@ Tensor = _obj.Tensor ...@@ -34,7 +34,9 @@ Tensor = _obj.Tensor
ADT = _obj.ADT ADT = _obj.ADT
def _convert(arg, cargs): def _convert(arg, cargs):
if isinstance(arg, (np.ndarray, tvm.nd.NDArray)): if isinstance(arg, _obj.Object):
cargs.append(arg)
elif isinstance(arg, (np.ndarray, tvm.nd.NDArray)):
cargs.append(_obj.Tensor(arg)) cargs.append(_obj.Tensor(arg))
elif isinstance(arg, (tuple, list)): elif isinstance(arg, (tuple, list)):
field_args = [] field_args = []
...@@ -42,7 +44,7 @@ def _convert(arg, cargs): ...@@ -42,7 +44,7 @@ def _convert(arg, cargs):
_convert(field, field_args) _convert(field, field_args)
cargs.append(_obj.tuple_object(field_args)) cargs.append(_obj.tuple_object(field_args))
else: else:
raise "unsupported type" raise "Unsupported type: %s" % (type(arg))
def convert(args): def convert(args):
...@@ -57,10 +59,13 @@ class Executable(object): ...@@ -57,10 +59,13 @@ class Executable(object):
"""Relay VM executable""" """Relay VM executable"""
def __init__(self, mod): def __init__(self, mod):
self.mod = mod self.mod = mod
self._function_params = {}
self._save = self.mod["save"] self._save = self.mod["save"]
self._get_lib = self.mod["get_lib"] self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"] self._get_bytecode = self.mod["get_bytecode"]
self._get_stats = self.mod["get_stats"] self._get_stats = self.mod["get_stats"]
self._get_function_arity = self.mod["get_function_arity"]
self._get_function_param_name = self.mod["get_function_param_name"]
def save(self): def save(self):
"""Save the Relay VM Executable. """Save the Relay VM Executable.
...@@ -239,6 +244,20 @@ class Executable(object): ...@@ -239,6 +244,20 @@ class Executable(object):
"""Return the runtime module contained in a virtual machine executable.""" """Return the runtime module contained in a virtual machine executable."""
return self.mod return self.mod
def get_function_params(self, func_name):
"""Get VM Function parameters"""
if func_name in self._function_params:
return self._function_params[func_name]
arity = self._get_function_arity(func_name)
assert arity >= 0
params = []
for i in range(arity):
p = self._get_function_param_name(func_name, i)
assert p
params.append(p)
self._function_params[func_name] = params
return params
class VirtualMachine(object): class VirtualMachine(object):
"""Relay VM runtime.""" """Relay VM runtime."""
...@@ -248,8 +267,10 @@ class VirtualMachine(object): ...@@ -248,8 +267,10 @@ class VirtualMachine(object):
"tvm.Module, but received {}".format(type(mod))) "tvm.Module, but received {}".format(type(mod)))
m = mod.module if isinstance(mod, Executable) else mod m = mod.module if isinstance(mod, Executable) else mod
self.mod = _vm._VirtualMachine(m) self.mod = _vm._VirtualMachine(m)
self._exec = mod
self._init = self.mod["init"] self._init = self.mod["init"]
self._invoke = self.mod["invoke"] self._invoke = self.mod["invoke"]
self._set_input = self.mod["set_input"]
def init(self, ctx): def init(self, ctx):
"""Initialize the context in the VM. """Initialize the context in the VM.
...@@ -262,7 +283,37 @@ class VirtualMachine(object): ...@@ -262,7 +283,37 @@ class VirtualMachine(object):
args = [ctx.device_type, ctx.device_id] args = [ctx.device_type, ctx.device_id]
self._init(*args) self._init(*args)
def invoke(self, func_name, *args): def set_input(self, func_name, *args, **kwargs):
"""Set the input to a function.
Parameters
----------
func_name : str
The name of the function.
args : list[NDArray] or list[np.ndarray]
The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
"""
if kwargs:
func_params = self._exec.get_function_params(func_name)
new_args = [None] * len(func_params)
assert len(args) + len(kwargs) == len(func_params)
for k in kwargs:
idx = func_params.index(k)
new_args[idx] = kwargs[k]
idx = 0
for i, arg in enumerate(new_args):
if arg is None:
new_args[i] = args[idx]
idx += 1
args = new_args
cargs = convert(args)
self._set_input(func_name, *cargs)
def invoke(self, func_name, *args, **kwargs):
"""Invoke a function. """Invoke a function.
Parameters Parameters
...@@ -273,15 +324,19 @@ class VirtualMachine(object): ...@@ -273,15 +324,19 @@ class VirtualMachine(object):
args : list[NDArray] or list[np.ndarray] args : list[NDArray] or list[np.ndarray]
The arguments to the function. The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
Returns Returns
------- -------
result : Object result : Object
The output. The output.
""" """
cargs = convert(args) if args or kwargs:
return self._invoke(func_name, *cargs) self.set_input(func_name, *args, **kwargs)
return self._invoke(func_name)
def run(self, *args): def run(self, *args, **kwargs):
"""Run the main function. """Run the main function.
Parameters Parameters
...@@ -289,12 +344,15 @@ class VirtualMachine(object): ...@@ -289,12 +344,15 @@ class VirtualMachine(object):
args : list[NDArray] or list[np.ndarray] args : list[NDArray] or list[np.ndarray]
The arguments to the function. The arguments to the function.
kwargs: dict of str to NDArray or np.ndarray
Named arguments to the function.
Returns Returns
------- -------
result : Object result : Object
The output. The output.
""" """
return self.invoke("main", *args) return self.invoke("main", *args, **kwargs)
def compile(mod, target=None, target_host=None, params=None): def compile(mod, target=None, target_host=None, params=None):
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/backend/vm/profiler/compiler.cc
* \brief A compiler from relay::Module to the VM byte code.
*/
#include "../../../../runtime/vm/profiler/vm.h"
#include "../compiler.h"
namespace tvm {
namespace relay {
namespace vm {
class VMCompilerDebug : public VMCompiler {
public:
VMCompilerDebug() {}
virtual ~VMCompilerDebug() {}
};
runtime::Module CreateVMCompilerDebug() {
auto exec = make_object<VMCompilerDebug>();
return runtime::Module(exec);
}
TVM_REGISTER_GLOBAL("relay._vm._VMCompilerProfiler")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateVMCompilerDebug();
});
} // namespace vm
} // namespace relay
} // namespace tvm
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <iostream> #include <iostream>
#include <iomanip>
#include <sstream> #include <sstream>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -67,44 +68,76 @@ PackedFunc Executable::GetFunction(const std::string& name, ...@@ -67,44 +68,76 @@ PackedFunc Executable::GetFunction(const std::string& name,
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->Save(); *rv = this->Save();
}); });
} else if (name == "get_function_arity") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
*rv = this->GetFunctionArity(func_name);
});
} else if (name == "get_function_param_name") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::string func_name = args[0];
int index = args[1];
*rv = this->GetFunctionParameterName(func_name, index);
});
} else { } else {
LOG(FATAL) << "Unknown packed function: " << name; LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc(nullptr); return PackedFunc(nullptr);
} }
} }
int Executable::GetFunctionArity(std::string func_name) const {
auto it = global_map.find(func_name);
if (it == global_map.end()) {
LOG(ERROR) << "Cannot find function " << func_name << " in executable";
return -1;
}
const auto& func = functions[it->second];
return func.params.size();
}
std::string Executable::GetFunctionParameterName(std::string func_name, uint32_t index) const {
auto it = global_map.find(func_name);
if (it == global_map.end()) {
LOG(ERROR) << "Cannot find function " << func_name << " in executable";
return "";
}
const auto& func = functions[it->second];
if (index > func.params.size()) {
LOG(ERROR) << "Invalid parameter index";
return "";
}
return func.params[index];
}
std::string Executable::GetBytecode() const { std::string Executable::GetBytecode() const {
std::ostringstream oss; std::ostringstream oss;
for (const auto& func : functions) { for (size_t i = 0; i < functions.size(); ++i) {
const auto& func = functions[i];
// Print the header of the function format. // Print the header of the function format.
oss << "# func name, reg file size, param count, inst count:" oss << "VM Function[" << i << "]: " << func.name << "(";
<< std::endl;
oss << func.name << " "
<< func.register_file_size << " "
<< func.params.size() << " "
<< func.instructions.size() << std::endl;
// Print pramams of a `VMFunction`.
oss << "# Parameters: "<< std::endl;
for (const auto& param : func.params) { for (const auto& param : func.params) {
oss << param << " "; oss << param << ", ";
} }
oss << std::endl; oss.seekp(-2, std::ios_base::end);
oss << ")" << std::endl;
oss << "# reg file size = " << func.register_file_size << std::endl;
oss << "# instruction count = " << func.instructions.size() << std::endl;
// Print the instructions of a `VMFunction`. // Print the instructions of a `VMFunction`.
// The part after ";" is the instruction in text format. // The part after ";" is the instruction in text format.
oss << "hash, opcode, fields # inst(text):"<< std::endl; oss << "opcode, fields # inst(text):" << std::endl;
for (const auto& instr : func.instructions) { for (size_t idx = 0; idx < func.instructions.size(); ++idx) {
const auto& instr = func.instructions[idx];
const auto& serialized_instr = SerializeInstruction(instr); const auto& serialized_instr = SerializeInstruction(instr);
oss << std::hex << "0x" << serialized_instr.Hash() << " " oss << std::setw(2) << idx << ": " << serialized_instr.opcode << " ";
<< std::dec << serialized_instr.opcode << " ";
for (auto it : serialized_instr.fields) { for (auto it : serialized_instr.fields) {
oss << it << " "; oss << it << " ";
} }
oss << " # " << instr; oss << " # " << instr;
if (oss.str().back() != '\n') oss << std::endl; if (oss.str().back() != '\n') oss << std::endl;
} }
oss << std::endl;
} }
return oss.str(); return oss.str();
......
...@@ -50,15 +50,15 @@ PackedFunc VirtualMachineDebug::GetFunction( ...@@ -50,15 +50,15 @@ PackedFunc VirtualMachineDebug::GetFunction(
<< "\t" << "\t"
<< "#Duration(us): Sum/Mean/Min/Max" << std::endl; << "#Duration(us): Sum/Mean/Min/Max" << std::endl;
for (auto kv : op_durations) { for (auto kv : op_durations_) {
auto vals = op_durations[kv.first]; auto vals = op_durations_[kv.first];
auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);; auto sum = std::accumulate(vals.begin(), vals.end(), 0.0);;
auto mean = sum / static_cast<double>(vals.size()); auto mean = sum / static_cast<double>(vals.size());
auto min_value = *std::min_element(vals.begin(), vals.end()); auto min_value = *std::min_element(vals.begin(), vals.end());
auto max_value = *std::max_element(vals.begin(), vals.end()); auto max_value = *std::max_element(vals.begin(), vals.end());
os << std::setw(30) << std::left << packed_index_map[kv.first] << "\t" os << std::setw(30) << std::left << packed_index_map_[kv.first] << "\t"
<< std::setw(10) << std::left << op_invokes[kv.first] << "\t" << std::setw(10) << std::left << op_invokes_[kv.first] << "\t"
<< sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl; << sum << "/" << mean << "/" << min_value << "/" << max_value << std::endl;
total_duration += sum; total_duration += sum;
...@@ -66,18 +66,10 @@ PackedFunc VirtualMachineDebug::GetFunction( ...@@ -66,18 +66,10 @@ PackedFunc VirtualMachineDebug::GetFunction(
os << "Total Duration " << total_duration << " us" << std::endl; os << "Total Duration " << total_duration << " us" << std::endl;
*rv = os.str(); *rv = os.str();
}); });
} else if (name == "init") { } else if (name == "reset") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size() % 2, 0); op_durations_.clear();
std::vector<TVMContext> contexts; op_invokes_.clear();
for (int i = 0; i < args.size() / 2; ++i) {
TVMContext ctx;
int device_type = args[i * 2];
ctx.device_type = DLDeviceType(device_type);
ctx.device_id = args[i * 2 + 1];
contexts.push_back(ctx);
}
this->Init(contexts);
}); });
} else { } else {
return VirtualMachine::GetFunction(name, sptr_to_self); return VirtualMachine::GetFunction(name, sptr_to_self);
...@@ -86,31 +78,25 @@ PackedFunc VirtualMachineDebug::GetFunction( ...@@ -86,31 +78,25 @@ PackedFunc VirtualMachineDebug::GetFunction(
void VirtualMachineDebug::LoadExecutable(const Executable* exec) { void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
VirtualMachine::LoadExecutable(exec); VirtualMachine::LoadExecutable(exec);
CHECK(this->exec); CHECK(exec_);
for (auto kv : this->exec->primitive_map) { for (auto kv : exec_->primitive_map) {
packed_index_map[kv.second] = kv.first; packed_index_map_[kv.second] = kv.first;
op_invokes[kv.second] = 0; op_invokes_[kv.second] = 0;
} }
} }
void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
VirtualMachine::Init(ctxs);
}
void VirtualMachineDebug::InvokePacked(Index packed_index, void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count, const PackedFunc& func, Index arg_count,
Index output_size, Index output_size,
const std::vector<ObjectRef>& args) { const std::vector<ObjectRef>& args) {
CHECK(this->exec); CHECK(exec_);
auto ctx = this->GetParamsContext(); auto ctx = this->GetParamsContext();
// warmup // warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args);
args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_begin = std::chrono::high_resolution_clock::now(); auto op_begin = std::chrono::high_resolution_clock::now();
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args);
args);
TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); TVMSynchronize(ctx.device_type, ctx.device_id, nullptr);
auto op_end = std::chrono::high_resolution_clock::now(); auto op_end = std::chrono::high_resolution_clock::now();
double op_duration = double op_duration =
...@@ -118,8 +104,8 @@ void VirtualMachineDebug::InvokePacked(Index packed_index, ...@@ -118,8 +104,8 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,
op_begin) op_begin)
.count(); .count();
op_durations[packed_index].push_back(op_duration * 1e6); op_durations_[packed_index].push_back(op_duration * 1e6);
op_invokes[packed_index] += 1; op_invokes_[packed_index] += 1;
} }
runtime::Module CreateVirtualMachineDebug(const Executable* exec) { runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
......
...@@ -43,19 +43,17 @@ class VirtualMachineDebug : public VirtualMachine { ...@@ -43,19 +43,17 @@ class VirtualMachineDebug : public VirtualMachine {
PackedFunc GetFunction(const std::string& name, PackedFunc GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) final; const ObjectPtr<Object>& sptr_to_self) final;
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, void LoadExecutable(const Executable* exec) final;
Index output_size, const std::vector<ObjectRef>& args) final;
void LoadExecutable(const Executable* exec);
~VirtualMachineDebug() {} ~VirtualMachineDebug() {}
private: private:
void Init(const std::vector<TVMContext>& ctxs); void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final;
std::unordered_map<Index, std::string> packed_index_map; std::unordered_map<Index, std::string> packed_index_map_;
std::unordered_map<Index, std::vector<double>> op_durations; std::unordered_map<Index, std::vector<double>> op_durations_;
std::unordered_map<Index, int> op_invokes; std::unordered_map<Index, int> op_invokes_;
}; };
} // namespace vm } // namespace vm
......
...@@ -544,7 +544,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -544,7 +544,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break; break;
} }
case Opcode::If: { case Opcode::If: {
os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " " os << "if " << "$" << instr.if_op.test << " $" << instr.if_op.target << " "
<< instr.if_op.true_offset << " " << instr.if_op.false_offset; << instr.if_op.true_offset << " " << instr.if_op.false_offset;
break; break;
} }
...@@ -565,7 +565,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -565,7 +565,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
break; break;
} }
case Opcode::LoadConsti: { case Opcode::LoadConsti: {
os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"; os << "load_consti $" << instr.dst << " " << instr.load_consti.val;
break; break;
} }
case Opcode::GetField: { case Opcode::GetField: {
...@@ -630,35 +630,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -630,35 +630,20 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
if (name == "invoke") { if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(exec) << "The executable is not created yet."; CHECK(exec_) << "The executable is not created yet.";
std::string func_name = args[0]; std::string func_name = args[0];
auto gvit = exec->global_map.find(func_name); auto git = exec_->global_map.find(func_name);
CHECK(gvit != exec->global_map.end()) << "Cannot find function " << func_name; CHECK(git != exec_->global_map.end())
auto func_index = gvit->second; << "Cannot find function " << func_name << " in the executable";
const auto& vm_func = exec->functions[func_index]; auto func = exec_->functions[git->second];
const auto& param_names = vm_func.params; if (func.params.empty()) {
auto ctx = this->GetParamsContext(); *rv = Invoke(func, {});
} else {
// Prepare the func args auto it = inputs_.find(func_name);
std::vector<ObjectRef> func_args(param_names.size()); CHECK(it != inputs_.end()) << "Input has not been set for function " << func_name;
std::vector<size_t> empty_slots; const std::vector<ObjectRef> &func_args = it->second;
*rv = Invoke(func, func_args);
for (size_t i = 0; i < param_names.size(); ++i) {
const auto& pit = params_.find(param_names[i]);
if (pit != params_.end()) {
func_args[i] = pit->second;
} else {
empty_slots.push_back(i);
}
}
CHECK_EQ(empty_slots.size(), args.size() - 1)
<< "The number of provided parameters doesn't match the number of arguments";
for (int i = 1; i < args.size(); ++i) {
ObjectRef obj = CopyTo(args[i], ctx);
func_args[empty_slots[i - 1]] = obj;
} }
*rv = this->Invoke(vm_func, func_args);
}); });
} else if (name == "init") { } else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
...@@ -673,6 +658,27 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -673,6 +658,27 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} }
this->Init(contexts); this->Init(contexts);
}); });
} else if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK(exec_) << "The executable is not created yet.";
std::string func_name = args[0];
auto gvit = exec_->global_map.find(func_name);
CHECK(gvit != exec_->global_map.end()) << "Cannot find function " << func_name;
auto func_index = gvit->second;
const auto& vm_func = exec_->functions[func_index];
const auto& param_names = vm_func.params;
// TODO(icemelon9): For heterogeneous execution, get input device information
TVMContext ctx = ctxs_[0];
CHECK_EQ(args.size() - 1, param_names.size()) <<
"The number of provided parameters doesn't match the number of arguments";
std::vector<ObjectRef> func_args(param_names.size());
for (int i = 1; i < args.size(); ++i) {
ObjectRef obj = CopyTo(args[i], ctx);
func_args[i - 1] = obj;
}
inputs_.erase(func_name);
inputs_.emplace(func_name, func_args);
});
} else { } else {
LOG(FATAL) << "Unknown packed function: " << name; LOG(FATAL) << "Unknown packed function: " << name;
return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {});
...@@ -680,47 +686,46 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, ...@@ -680,47 +686,46 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
} }
TVMContext VirtualMachine::GetParamsContext() const { TVMContext VirtualMachine::GetParamsContext() const {
CHECK(!ctxs.empty()) << "Context has not been initialized yet." CHECK(!ctxs_.empty()) << "Context has not been initialized yet.";
<< "\n";
// Use the fallback device if no device index is available. // Use the fallback device if no device index is available.
int fallback_device_type = static_cast<int>(ctxs[0].device_type); int fallback_device_type = static_cast<int>(ctxs_[0].device_type);
// TODO(wweic): For heterogeneous execution, get device information from byte // TODO(wweic): For heterogeneous execution, get device information from byte
const auto& cit = const auto& cit =
std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) { std::find_if(ctxs_.begin(), ctxs_.end(), [&fallback_device_type](const TVMContext& c) {
return fallback_device_type == static_cast<int>(c.device_type); return fallback_device_type == static_cast<int>(c.device_type);
}); });
return (cit == ctxs.end() ? ctxs[0] : *cit); return (cit == ctxs_.end() ? ctxs_[0] : *cit);
} }
void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) {
auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); auto frame = VMFrame(ret_pc, func_index_, arg_count, code_, vm_func.register_file_size);
frames.push_back(frame); frames_.push_back(frame);
} }
Index VirtualMachine::PopFrame() { Index VirtualMachine::PopFrame() {
CHECK_GT(frames.size(), 0); CHECK_GT(frames_.size(), 0);
const VMFrame& fr = frames.back(); const VMFrame& fr = frames_.back();
func_index = fr.func_index; func_index_ = fr.func_index;
code = fr.code; code_ = fr.code;
pc = fr.pc; pc_ = fr.pc;
auto call_stack_size = frames.size(); auto call_stack_size = frames_.size();
frames.pop_back(); frames_.pop_back();
return call_stack_size; return call_stack_size;
} }
void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args) { void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args) {
DLOG(INFO) << "Invoking global " << func.name << " " << args.size(); DLOG(INFO) << "Invoking global " << func.name << " " << args.size();
PushFrame(func.params.size(), this->pc + 1, func); PushFrame(func.params.size(), this->pc_ + 1, func);
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
WriteRegister(i, args[i]); WriteRegister(i, args[i]);
} }
DLOG(INFO) << "func.params= " << func.params.size(); DLOG(INFO) << "func.params= " << func.params.size();
code = func.instructions.data(); code_ = func.instructions.data();
pc = 0; pc_ = 0;
} }
ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) { ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<ObjectRef>& args) {
...@@ -729,16 +734,19 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec ...@@ -729,16 +734,19 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec
InvokeGlobal(func, args); InvokeGlobal(func, args);
RunLoop(); RunLoop();
// TODO(wweic) ctx could be obtained from the ctxs list. // TODO(wweic) ctx could be obtained from the ctxs list.
auto alloc = MemoryManager::Global()->GetAllocator(ctxs[0]); auto alloc = MemoryManager::Global()->GetAllocator(ctxs_[0]);
DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B"; DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
return return_register; return return_register_;
} }
ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) { ObjectRef VirtualMachine::Invoke(const std::string& name, const std::vector<ObjectRef>& args) {
CHECK(exec) << "The executable has not been created yet."; CHECK(exec_) << "The executable has not been created yet.";
auto func_index = exec->global_map.at(name); auto it = exec_->global_map.find(name);
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index; CHECK(it != exec_->global_map.end())
return Invoke(exec->functions[func_index], args); << "Cannot find function " << name << " in the executable";
auto func_index_ = it->second;
DLOG(INFO) << "Invoke Global " << name << " at index " << func_index_;
return Invoke(exec_->functions[func_index_], args);
} }
void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
...@@ -777,34 +785,34 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, ...@@ -777,34 +785,34 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
void VirtualMachine::LoadExecutable(const Executable* exec) { void VirtualMachine::LoadExecutable(const Executable* exec) {
CHECK(exec) << "The executable is not created yet."; CHECK(exec) << "The executable is not created yet.";
this->exec = exec; exec_ = exec;
runtime::Module lib = this->exec->lib; runtime::Module lib = exec_->lib;
// Get the list of packed functions. // Get the list of packed functions.
CHECK(exec->primitive_map.empty() || lib.operator->()) CHECK(exec->primitive_map.empty() || lib.operator->())
<< "runtime module should have been built for primitive functions" << "runtime module should have been built for primitive functions"
<< "\n"; << "\n";
for (const auto& it : this->exec->primitive_map) { for (const auto& it : exec_->primitive_map) {
const auto& packed_name = it.first; const auto& packed_name = it.first;
auto packed_index = static_cast<size_t>(it.second); auto packed_index = static_cast<size_t>(it.second);
if (packed_funcs.size() <= packed_index) { if (packed_funcs_.size() <= packed_index) {
packed_funcs.resize(packed_index + 1); packed_funcs_.resize(packed_index + 1);
} }
packed_funcs[packed_index] = lib.GetFunction(packed_name); packed_funcs_[packed_index] = lib.GetFunction(packed_name);
} }
} }
void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) {
this->ctxs = ctxs; ctxs_ = ctxs;
} }
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
frames.back().register_file[r] = val; frames_.back().register_file[r] = val;
} }
inline ObjectRef VirtualMachine::ReadRegister(Index r) const { inline ObjectRef VirtualMachine::ReadRegister(Index r) const {
return frames.back().register_file[r]; return frames_.back().register_file[r];
} }
inline int32_t VirtualMachine::LoadScalarInt(Index r) const { inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
...@@ -825,14 +833,14 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const { ...@@ -825,14 +833,14 @@ inline int32_t VirtualMachine::LoadScalarInt(Index r) const {
} }
void VirtualMachine::RunLoop() { void VirtualMachine::RunLoop() {
CHECK(this->code); CHECK(this->exec_);
CHECK(this->exec); CHECK(this->code_);
this->pc = 0; pc_ = 0;
Index frame_start = frames.size(); Index frame_start = frames_.size();
while (true) { while (true) {
main_loop: main_loop:
auto const& instr = this->code[this->pc]; auto const& instr = code_[this->pc_];
DLOG(INFO) << "Executing(" << pc << "): " << instr; DLOG(INFO) << "Executing(" << pc_ << "): " << instr;
#if USE_RELAY_DEBUG #if USE_RELAY_DEBUG
InstructionPrint(std::cout, instr); InstructionPrint(std::cout, instr);
#endif // USE_RELAY_DEBUG #endif // USE_RELAY_DEBUG
...@@ -842,14 +850,14 @@ void VirtualMachine::RunLoop() { ...@@ -842,14 +850,14 @@ void VirtualMachine::RunLoop() {
ObjectRef from_obj; ObjectRef from_obj;
from_obj = ReadRegister(instr.from); from_obj = ReadRegister(instr.from);
WriteRegister(instr.dst, from_obj); WriteRegister(instr.dst, from_obj);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::Fatal: { case Opcode::Fatal: {
throw std::runtime_error("VM encountered fatal error"); throw std::runtime_error("VM encountered fatal error");
} }
case Opcode::LoadConst: { case Opcode::LoadConst: {
auto constant_obj = exec->constants[instr.const_index]; auto constant_obj = exec_->constants[instr.const_index];
// We cache the allocated object in the constant pool. To measure, the // We cache the allocated object in the constant pool. To measure, the
// first iteration will set the pool up. The other iterations will // first iteration will set the pool up. The other iterations will
// directly reuse the allocated objects. // directly reuse the allocated objects.
...@@ -859,17 +867,17 @@ void VirtualMachine::RunLoop() { ...@@ -859,17 +867,17 @@ void VirtualMachine::RunLoop() {
if (!const_pool_[instr.const_index].defined()) { if (!const_pool_[instr.const_index].defined()) {
// TODO(wweic) ctx could be obtained from the ctxs list. // TODO(wweic) ctx could be obtained from the ctxs list.
const_pool_[instr.const_index] = CopyTo(constant_obj, ctxs[0]); const_pool_[instr.const_index] = CopyTo(constant_obj, ctxs_[0]);
} }
WriteRegister(instr.dst, const_pool_[instr.const_index]); WriteRegister(instr.dst, const_pool_[instr.const_index]);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::LoadConsti: { case Opcode::LoadConsti: {
auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0});
reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val; reinterpret_cast<int64_t*>(tensor->data)[0] = instr.load_consti.val;
WriteRegister(instr.dst, Tensor(tensor)); WriteRegister(instr.dst, Tensor(tensor));
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::Invoke: { case Opcode::Invoke: {
...@@ -877,14 +885,13 @@ void VirtualMachine::RunLoop() { ...@@ -877,14 +885,13 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_args; ++i) { for (Index i = 0; i < instr.num_args; ++i) {
args.push_back(ReadRegister(instr.invoke_args_registers[i])); args.push_back(ReadRegister(instr.invoke_args_registers[i]));
} }
InvokeGlobal(exec->functions[instr.func_index], args); InvokeGlobal(exec_->functions[instr.func_index], args);
frames.back().caller_return_register = instr.dst; frames_.back().caller_return_register = instr.dst;
goto main_loop; goto main_loop;
} }
case Opcode::InvokePacked: { case Opcode::InvokePacked: {
DLOG(INFO) << "InvokedPacked " DLOG(INFO) << "InvokedPacked " << "arity=" << instr.arity;
<< "arity=" << instr.arity; const auto& func = packed_funcs_[instr.packed_index];
const auto& func = packed_funcs[instr.packed_index];
const auto& arity = instr.arity; const auto& arity = instr.arity;
std::vector<ObjectRef> args; std::vector<ObjectRef> args;
for (Index i = 0; i < arity; ++i) { for (Index i = 0; i < arity; ++i) {
...@@ -897,7 +904,7 @@ void VirtualMachine::RunLoop() { ...@@ -897,7 +904,7 @@ void VirtualMachine::RunLoop() {
// We no longer need to write the registers back, we write directly // We no longer need to write the registers back, we write directly
// through the registers mutably. // through the registers mutably.
InvokePacked(instr.packed_index, func, arity, instr.output_size, args); InvokePacked(instr.packed_index, func, arity, instr.output_size, args);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::InvokeClosure: { case Opcode::InvokeClosure: {
...@@ -911,8 +918,8 @@ void VirtualMachine::RunLoop() { ...@@ -911,8 +918,8 @@ void VirtualMachine::RunLoop() {
for (Index i = 0; i < instr.num_closure_args; ++i) { for (Index i = 0; i < instr.num_closure_args; ++i) {
args.push_back(ReadRegister(instr.closure_args[i])); args.push_back(ReadRegister(instr.closure_args[i]));
} }
InvokeGlobal(exec->functions[closure->func_index], args); InvokeGlobal(exec_->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst; frames_.back().caller_return_register = instr.dst;
goto main_loop; goto main_loop;
} }
case Opcode::GetField: { case Opcode::GetField: {
...@@ -923,7 +930,7 @@ void VirtualMachine::RunLoop() { ...@@ -923,7 +930,7 @@ void VirtualMachine::RunLoop() {
<< object->type_index(); << object->type_index();
auto field = tuple->fields[instr.field_index]; auto field = tuple->fields[instr.field_index];
WriteRegister(instr.dst, field); WriteRegister(instr.dst, field);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::GetTag: { case Opcode::GetTag: {
...@@ -937,11 +944,11 @@ void VirtualMachine::RunLoop() { ...@@ -937,11 +944,11 @@ void VirtualMachine::RunLoop() {
auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag; reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
WriteRegister(instr.dst, Tensor(tag_tensor)); WriteRegister(instr.dst, Tensor(tag_tensor));
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::Goto: { case Opcode::Goto: {
pc += instr.pc_offset; pc_ += instr.pc_offset;
goto main_loop; goto main_loop;
} }
case Opcode::If: { case Opcode::If: {
...@@ -950,10 +957,10 @@ void VirtualMachine::RunLoop() { ...@@ -950,10 +957,10 @@ void VirtualMachine::RunLoop() {
if (test_val == target_val) { if (test_val == target_val) {
CHECK_NE(instr.if_op.true_offset, 0); CHECK_NE(instr.if_op.true_offset, 0);
pc += instr.if_op.true_offset; pc_ += instr.if_op.true_offset;
} else { } else {
CHECK_NE(instr.if_op.false_offset, 0); CHECK_NE(instr.if_op.false_offset, 0);
pc += instr.if_op.false_offset; pc_ += instr.if_op.false_offset;
} }
goto main_loop; goto main_loop;
...@@ -971,7 +978,7 @@ void VirtualMachine::RunLoop() { ...@@ -971,7 +978,7 @@ void VirtualMachine::RunLoop() {
auto obj = Tensor(data); auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocTensorReg: { case Opcode::AllocTensorReg: {
...@@ -996,7 +1003,7 @@ void VirtualMachine::RunLoop() { ...@@ -996,7 +1003,7 @@ void VirtualMachine::RunLoop() {
auto obj = Tensor(data); auto obj = Tensor(data);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocADT: { case Opcode::AllocADT: {
...@@ -1006,7 +1013,7 @@ void VirtualMachine::RunLoop() { ...@@ -1006,7 +1013,7 @@ void VirtualMachine::RunLoop() {
} }
ObjectRef obj = ADT(instr.constructor_tag, fields); ObjectRef obj = ADT(instr.constructor_tag, fields);
WriteRegister(instr.dst, obj); WriteRegister(instr.dst, obj);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocClosure: { case Opcode::AllocClosure: {
...@@ -1015,7 +1022,7 @@ void VirtualMachine::RunLoop() { ...@@ -1015,7 +1022,7 @@ void VirtualMachine::RunLoop() {
free_vars.push_back(ReadRegister(instr.free_vars[i])); free_vars.push_back(ReadRegister(instr.free_vars[i]));
} }
WriteRegister(instr.dst, Closure(instr.func_index, free_vars)); WriteRegister(instr.dst, Closure(instr.func_index, free_vars));
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::AllocStorage: { case Opcode::AllocStorage: {
...@@ -1027,23 +1034,23 @@ void VirtualMachine::RunLoop() { ...@@ -1027,23 +1034,23 @@ void VirtualMachine::RunLoop() {
"alignment=" << alignment << "alignment=" << alignment <<
"dtype_hint=" << TVMType2String(instr.alloc_storage.dtype_hint); "dtype_hint=" << TVMType2String(instr.alloc_storage.dtype_hint);
auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs[0]); auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]);
WriteRegister(instr.dst, storage); WriteRegister(instr.dst, storage);
pc++; pc_++;
goto main_loop; goto main_loop;
} }
case Opcode::Ret: { case Opcode::Ret: {
// If we have hit the point from which we started // If we have hit the point from which we started
// running, we should return to the caller breaking // running, we should return to the caller breaking
// the dispatch loop. // the dispatch loop.
return_register = ReadRegister(instr.result); return_register_ = ReadRegister(instr.result);
auto caller_return_register = frames.back().caller_return_register; auto caller_return_register = frames_.back().caller_return_register;
if (PopFrame() == frame_start) { if (PopFrame() == frame_start) {
return; return;
// Otherwise we are just returning from a local call. // Otherwise we are just returning from a local call.
} else { } else {
WriteRegister(caller_return_register, return_register); WriteRegister(caller_return_register, return_register_);
goto main_loop; goto main_loop;
} }
} }
...@@ -1061,8 +1068,7 @@ TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine") ...@@ -1061,8 +1068,7 @@ TVM_REGISTER_GLOBAL("relay._vm._VirtualMachine")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0]; runtime::Module mod = args[0];
const auto* exec = dynamic_cast<Executable*>(mod.operator->()); const auto* exec = dynamic_cast<Executable*>(mod.operator->());
CHECK(exec) << "The virtual machine executable has not been defined yet." CHECK(exec) << "The virtual machine executable has not been defined yet.";
<< "\n";
*rv = CreateVirtualMachine(exec); *rv = CreateVirtualMachine(exec);
}); });
......
...@@ -47,18 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"): ...@@ -47,18 +47,13 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr): if isinstance(f, relay.Expr):
mod = relay.Module() mod = relay.Module()
mod["main"] = f mod["main"] = f
exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe)
vm.init(ctx)
return vm.invoke("main", *args)
else: else:
assert isinstance(f, relay.Module), "expected expression or module" assert isinstance(f, relay.Module), "expected expression or module"
mod = f mod = f
exe = relay.vm.compile(mod, target) exe = relay.vm.compile(mod, target)
vm = relay.vm.VirtualMachine(exe) vm = relay.vm.VirtualMachine(exe)
vm.init(ctx) vm.init(ctx)
ret = vm.invoke("main", *args) return vm.invoke("main", *args)
return ret
def vmobj_to_list(o): def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.vm.Tensor): if isinstance(o, tvm.relay.backend.vm.Tensor):
...@@ -577,35 +572,4 @@ def test_add_op_broadcast(): ...@@ -577,35 +572,4 @@ def test_add_op_broadcast():
if __name__ == "__main__": if __name__ == "__main__":
test_id() pytest.main()
test_op()
test_cond()
test_simple_if()
test_simple_call()
test_count_loop()
test_sum_loop()
test_tuple_fst()
test_tuple_second()
test_let_scalar()
test_let_tensor()
test_split()
test_split_no_fuse()
test_list_constructor()
test_let_tensor()
test_let_scalar()
test_compose()
test_list_hd()
test_list_tl_empty_list()
test_list_tl()
test_list_nth()
test_list_update()
test_list_length()
test_list_map()
test_list_foldl()
test_list_foldr()
test_list_sum()
test_list_filter()
test_closure()
test_add_op_scalar()
test_add_op_tensor()
test_add_op_broadcast()
...@@ -107,9 +107,9 @@ def test_serializer(): ...@@ -107,9 +107,9 @@ def test_serializer():
assert any(item.startswith('fused_multiply') for item in prim_ops) assert any(item.startswith('fused_multiply') for item in prim_ops)
code = exe.bytecode code = exe.bytecode
assert "main 8 2 8" in code assert "main(x1, y1)" in code
assert "f1 5 1 6" in code assert "f1(x)" in code
assert "f2 5 1 6" in code assert "f2(y)" in code
code, lib = exe.save() code, lib = exe.save()
assert isinstance(code, bytearray) assert isinstance(code, bytearray)
......
...@@ -28,7 +28,7 @@ def test_basic(): ...@@ -28,7 +28,7 @@ def test_basic():
ctx = tvm.cpu() ctx = tvm.cpu()
if not relay.profiler_vm.enabled(): if not relay.profiler_vm.enabled():
return return
exe = relay.profiler_vm.compile(mod, target, params=params) exe = relay.vm.compile(mod, target, params=params)
vm = relay.profiler_vm.VirtualMachineProfiler(exe) vm = relay.profiler_vm.VirtualMachineProfiler(exe)
vm.init(ctx) vm.init(ctx)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment