Commit c48812dd by Jared Roesch Committed by Tianqi Chen

[RELAY][RUNTIME] Refactor interpreter and graph_runtime into consistent interface. (#2042)

parent 0319f99d
......@@ -7,6 +7,8 @@ from . import ty
from . import expr
from . import env
from . import ir_pass
from .build_module import build
from .interpreter import create_executor
# Root operators
from .op import Op
......
"""
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
from ..build_module import build as tvm_build_module
from . graph_runtime_codegen import GraphRuntimeCodegen
from . import ir_pass
from .env import Environment
def build(func, params=None, target=None, env=None):
"""
Compile a single function to the components needed by the
TVM RTS.
Parameters
----------
func: relay.Expr
The function to build.
target: optional str
The target platform.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
"""
if target is None:
target = 'llvm'
if env is None:
env = Environment({})
comp = GraphRuntimeCodegen(env)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(env, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
comp.codegen(func)
graph_json = comp.to_json()
return graph_json, mod, params
......@@ -319,6 +319,118 @@ class TupleGetItem(Expr):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_value, index)
class ExprFunctor(object):
"""
An abstract visitor defined over Expr.
A Python version of the class defined in `expr_functor.h`.
Defines the default dispatch over expressions, and
implements memoization.
"""
def __init__(self):
self.memo_map = {}
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
self.memo_map[expr] = res
return res
def visit_function(self, _):
raise Exception("Abstract method please implement me.")
def visit_let(self, _):
raise Exception("Abstract method please implement me.")
def visit_call(self, _):
raise Exception("Abstract method please implement me.")
def visit_var(self, _):
raise Exception("Abstract method please implement me.")
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise Exception("Abstract method please implement me.")
def visit_tuple(self, _):
raise Exception("Abstract method please implement me.")
def visit_constant(self, _):
raise Exception("Abstract method please implement me.")
def visit_global_var(self, _):
raise Exception("Abstract method please implement me.")
class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
fn.ret_type, new_body,
fn.type_params)
def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, rvar):
return rvar
def visit_global_id(self, global_var):
return global_var
def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_constant(self, rconst):
return rconst
class TupleWrapper(object):
"""TupleWrapper.
......
......@@ -25,113 +25,7 @@ import json
import attr
from . import ir_pass
from .op import Op
from .expr import Var, Function, Call, If, GlobalVar, Constant, Let, Tuple
from ..build_module import build as tvm_build_module
from .. contrib import graph_runtime
from .ir_pass import infer_type
from .. import cpu
class AbstractExprVisitor(object):
"""A visitor over Expr in Python."""
def __init__(self):
self.memo_map = {}
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
self.memo_map[expr] = res
return res
def visit_function(self, _):
raise Exception("Abstract method please implement me.")
def visit_let(self, _):
raise Exception("Abstract method please implement me.")
def visit_call(self, _):
raise Exception("Abstract method please implement me.")
def visit_var(self, _):
raise Exception("Abstract method please implement me.")
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise Exception("Abstract method please implement me.")
def visit_tuple(self, _):
raise Exception("Abstract method please implement me.")
def visit_constant(self, _):
raise Exception("Abstract method please implement me.")
def visit_global_var(self, _):
raise Exception("Abstract method please implement me.")
class ExprMutator(AbstractExprVisitor):
"""A functional visitor over Expr in Python."""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
fn.ret_type, new_body,
fn.type_params)
def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, var):
return var
def visit_global_id(self, global_var):
return global_var
def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_constant(self, const):
return const
from .expr import Function, GlobalVar, ExprMutator
@attr.s
......@@ -359,8 +253,8 @@ class GraphRuntimeCodegen(ExprMutator):
self.add_binding(ident, val_ref)
return self.visit(body)
def visit_var(self, var):
return self.lookup(var)
def visit_var(self, rvar):
return self.lookup(rvar)
def visit_call(self, call):
"""Transform a ::tvm.relay.Call into an operator in the TVM graph."""
......@@ -472,80 +366,3 @@ class GraphRuntimeCodegen(ExprMutator):
}
return json.dumps(json_dict)
def build(env, func, target=None):
"""
Compile a single function to the components needed by the
TVM RTS.
Parameters
----------
func: relay.Expr
The function to build.
target: optional str
The target platform.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
"""
if target is None:
target = 'llvm'
comp = GraphRuntimeCodegen(env)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(env, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
comp.codegen(func)
graph_json = comp.to_json()
return graph_json, mod, None # params currently isn't supported by API
def graph_evaluate(env, func, *args):
"""
Corresponding function to tvm.relay.eval.evaluate.
This function evaluates a Relay expression on the
TVM graph_runtime.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
expr: tvm.relay.Expr
The expression to evaluate.
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
Returns
-------
value: tvm.NDArray
The output Tensor produced by evaluating the expression.
"""
func = infer_type(func, env)
func = ir_pass.fuse_ops(env, func)
func = infer_type(func, env)
graph_json, mod, params = build(env, func)
assert params is None
gmodule = graph_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
inputs = {}
for i, arg in enumerate(args):
inputs[func.params[i].name_hint] = arg
# Set the inputs here.
gmodule.set_input(**inputs)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
......@@ -4,12 +4,16 @@ from __future__ import absolute_import
import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node
from . import build_module
from . import _make
from . import _interpreter
from . import ir_pass
from .expr import Call, Constant, GlobalVar
from . import const
from .env import Environment
from .expr import Call, Constant, GlobalVar, Function, const
from .scope_builder import ScopeBuilder
from .._ffi.base import integer_types
from ..contrib import graph_runtime as tvm_runtime
from .. import cpu
class Value(NodeBase):
"""Base class of all values.
......@@ -83,48 +87,122 @@ def _arg_to_ast(arg):
else:
return const(arg)
class Executor(object):
"""An abstract interface for executing Relay programs."""
def __init__(self, env=None):
"""
Parameters
----------
env: relay.Environment
The environment.
"""
if env is None:
self.env = Environment({})
else:
self.env = env
def apply_passes(expr, env=None):
ck_expr = ir_pass.infer_type(expr, env=env)
fused_expr = ir_pass.fuse_ops(env, ck_expr)
return fused_expr
def optimize(self, expr):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr = ir_pass.infer_type(expr, env=self.env)
fused_expr = ir_pass.fuse_ops(self.env, ck_expr)
ck_fused = ir_pass.infer_type(fused_expr, env=self.env)
return ck_fused
def evaluate(env, expr, *args):
def _make_executor(self, _):
"""
Evaluate a Relay expression on the interpreter.
Construct a Python function that implements the evaluation
of expression.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
expr: relay.Expr
The Relay expression to execute.
Returns
-------
executor: function
A Python function which implements the behavior of `expr`.
"""
raise Exception("abstract method: please implement me.")
def evaluate(self, expr, params=None):
"""
Evaluate a Relay expression on the interpreter.
Parameters
----------
expr: tvm.relay.Expr
The expression to evaluate.
"""
if params:
scope_builder = ScopeBuilder()
for key, value in params:
scope_builder.let(key, value)
scope_builder.ret(expr)
expr = scope_builder.get()
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
if isinstance(expr, Function):
assert not ir_pass.free_vars(expr)
Returns
-------
value: tvm.relay.eval.Value
The value produced by evaluating the expression.
return self._make_executor(expr)
class Interpreter(Executor):
"""
A wrapper around the Relay interpreter, implements the excecutor interface.
"""
# assert len(args) == 0
def __init__(self, env=None):
Executor.__init__(self, env)
def _make_executor(self, expr):
def _interp_wrapper(*args):
relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))
# TODO: We need to move this optimization code into the optimizer/pass manager
if isinstance(expr, GlobalVar):
func = env[expr]
func = apply_passes(func, env)
env._add(expr, func, True)
func = self.env[expr]
func = self.optimize(func)
self.env._add(expr, func, True)
opt_expr = Call(expr, relay_args)
# import pdb; pdb.set_trace()
return _interpreter.evaluate(env, opt_expr)
return _interpreter.evaluate(self.env, opt_expr)
else:
call = Call(expr, relay_args)
opt_expr = self.optimize(call)
return _interpreter.evaluate(self.env, opt_expr)
return _interp_wrapper
class GraphRuntime(Executor):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def __init__(self, env=None):
Executor.__init__(self, env)
def _make_executor(self, expr):
def _graph_wrapper(*args):
func = self.optimize(expr)
graph_json, mod, params = build_module.build(func, env=self.env)
assert params is None
gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
inputs = {}
for i, arg in enumerate(args):
inputs[func.params[i].name_hint] = arg
# Set the inputs here.
gmodule.set_input(**inputs)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
return _graph_wrapper
def create_executor(mode='debug', env=None):
if mode == 'debug':
return Interpreter(env)
elif mode == 'graph':
return GraphRuntime(env)
else:
expr = Call(expr, relay_args)
opt_expr = apply_passes(expr, env)
return _interpreter.evaluate(env, opt_expr)
raise Exception("unknown mode {0}".format(mode))
import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.ir_pass import infer_type
from tvm.relay.interpreter import evaluate
from tvm.relay.graph_runtime_codegen import graph_evaluate
from tvm.relay.interpreter import Interpreter
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.env import Environment
# @tq, @jr should we put this in testing ns?
def check_rts(env, expr, args, expected_result):
def check_rts(expr, args, expected_result, env=None):
"""
Check that evaluating `expr` applied to the arguments produces
`result` on both the evaluator and TVM runtime.
......@@ -25,8 +25,10 @@ def check_rts(env, expr, args, expected_result):
expected_result:
The expected result of running the expression.
"""
eval_result = evaluate(env, expr, *args)
rts_result = graph_evaluate(env, expr, *args)
intrp = create_executor('graph', env=env)
graph = create_executor('graph', env=env)
eval_result = intrp.evaluate(expr)(*args)
rts_result = graph.evaluate(expr)(*args)
np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
def test_add_op_scalar():
......@@ -36,13 +38,12 @@ def test_add_op_scalar():
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=())
y = relay.var('y', shape=())
func = relay.Function([x, y], add(x, y))
x_data = np.array(10.0, dtype='float32')
y_data = np.array(1.0, dtype='float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
check_rts(func, [x_data, y_data], x_data + y_data)
def test_add_op_tensor():
"""
......@@ -51,13 +52,12 @@ def test_add_op_tensor():
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(10, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(10, 5).astype('float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
check_rts(func, [x_data, y_data], x_data + y_data)
def test_add_op_broadcast():
"""
......@@ -66,13 +66,12 @@ def test_add_op_broadcast():
return x + y;
}
"""
env = Environment()
x = relay.var('x', shape=(10, 5))
y = relay.var('y', shape=(1, 5))
func = relay.Function([x, y], add(x, y))
x_data = np.random.rand(10, 5).astype('float32')
y_data = np.random.rand(1, 5).astype('float32')
check_rts(env, func, [x_data, y_data], x_data + y_data)
check_rts(func, [x_data, y_data], x_data + y_data)
if __name__ == "__main__":
test_add_op_scalar()
......
import numpy as np
import tvm
from tvm import relay
from tvm.relay.interpreter import Value, TupleValue, evaluate
from tvm.relay.interpreter import Value, TupleValue
from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing
from tvm.relay import testing, create_executor
def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
if env is None:
env = relay.env.Environment({})
result = evaluate(env, expr, *args)
intrp = create_executor(env=env)
result = intrp.evaluate(expr)(*args)
np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
......@@ -32,8 +30,6 @@ def test_tuple_value():
def test_id():
x = relay.var('x', 'float32')
ident = relay.Function([x], x)
env = relay.env.Environment({})
res = evaluate(env, ident, 1.0)
check_eval(ident, [1.0], 1.0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment