Commit c48812dd by Jared Roesch Committed by Tianqi Chen

[RELAY][RUNTIME] Refactor interpreter and graph_runtime into consistent interface. (#2042)

parent 0319f99d
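In short, the refactor replaces the two backend-specific entry points with a single factory plus a shared evaluate interface. Schematically (names as they appear in the diffs below; this summary is illustrative, not part of the commit):

# Before: one entry point per backend, each with its own signature.
#   result = evaluate(env, expr, *args)          # interpreter
#   result = graph_evaluate(env, expr, *args)    # graph runtime
#
# After: one factory, one interface.
#   executor = create_executor('debug')          # or 'graph'
#   result = executor.evaluate(expr)(*args)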
@@ -7,6 +7,8 @@ from . import ty
from . import expr
from . import env
from . import ir_pass
+from .build_module import build
+from .interpreter import create_executor
# Root operators
from .op import Op
...
"""
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
from ..build_module import build as tvm_build_module
from .graph_runtime_codegen import GraphRuntimeCodegen
from . import ir_pass
from .env import Environment
def build(func, params=None, target=None, env=None):
"""
Compile a single function to the components needed by the
TVM RTS.
Parameters
----------
func: relay.Expr
    The function to build.
params: dict, optional
    Parameters for the function; passed through unchanged in the result.
target: str, optional
    The target platform; defaults to 'llvm'.
env: relay.Environment, optional
    The global environment; an empty one is created if omitted.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
"""
if target is None:
target = 'llvm'
if env is None:
env = Environment({})
comp = GraphRuntimeCodegen(env)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(env, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
comp.codegen(func)
graph_json = comp.to_json()
return graph_json, mod, params
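For orientation, a minimal sketch of driving this entry point end to end. It mirrors what GraphRuntime._make_executor in interpreter.py does below; the scalar-add function and the explicit infer_type/fuse_ops calls are illustrative assumptions, not part of the commit.

import numpy as np
import tvm
from tvm import relay
from tvm.relay import ir_pass
from tvm.relay.env import Environment
from tvm.relay.op import add
from tvm.contrib import graph_runtime

env = Environment({})
x = relay.var('x', shape=())
y = relay.var('y', shape=())
func = relay.Function([x, y], add(x, y))

# Type-check and fuse before building, as the executor wrapper does.
func = ir_pass.infer_type(func, env=env)
func = ir_pass.fuse_ops(env, func)
func = ir_pass.infer_type(func, env=env)

graph_json, mod, params = relay.build(func, env=env)

# The three components plug directly into the existing graph runtime.
gmodule = graph_runtime.create(graph_json, mod, tvm.cpu(0))
gmodule.set_input(x=np.array(3.0, dtype='float32'),
                  y=np.array(4.0, dtype='float32'))
gmodule.run()
print(gmodule.get_output(0).asnumpy())  # 7.0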
@@ -319,6 +319,118 @@ class TupleGetItem(Expr):
        self.__init_handle_by_constructor__(
            _make.TupleGetItem, tuple_value, index)
class ExprFunctor(object):
"""
An abstract visitor defined over Expr.
A Python version of the class defined in `expr_functor.h`.
Defines the default dispatch over expressions, and
implements memoization.
"""
def __init__(self):
self.memo_map = {}
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
if expr in self.memo_map:
    return self.memo_map[expr]
if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
self.memo_map[expr] = res
return res
def visit_function(self, _):
raise Exception("Abstract method please implement me.")
def visit_let(self, _):
raise Exception("Abstract method please implement me.")
def visit_call(self, _):
raise Exception("Abstract method please implement me.")
def visit_var(self, _):
raise Exception("Abstract method please implement me.")
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise Exception("Abstract method please implement me.")
def visit_tuple(self, _):
raise Exception("Abstract method please implement me.")
def visit_constant(self, _):
raise Exception("Abstract method please implement me.")
def visit_global_var(self, _):
raise Exception("Abstract method please implement me.")
class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.
The default behavior recursively traverses and
reconstructs the AST.
"""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
fn.ret_type, new_body,
fn.type_params)
def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, rvar):
return rvar
def visit_global_var(self, global_var):
    return global_var
def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_constant(self, rconst):
return rconst
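As a usage sketch, a pass can subclass ExprMutator and override only the cases it cares about; everything else falls through to the reconstructing defaults. The CallCounter below is a hypothetical example, not part of this commit (it rebuilds calls itself rather than deferring to the default visit_call, since the dispatch above has no case for operator nodes):

from tvm import relay
from tvm.relay.expr import ExprMutator, Call
from tvm.relay.op import add

class CallCounter(ExprMutator):
    """Hypothetical pass: count Call nodes while rebuilding the AST."""
    def __init__(self):
        ExprMutator.__init__(self)
        self.count = 0

    def visit_call(self, call):
        self.count += 1
        # Rebuild the call after visiting its arguments; call.op is an
        # operator node, which visit() does not dispatch on, so it is
        # kept as-is rather than visited.
        new_args = [self.visit(arg) for arg in call.args]
        return Call(call.op, new_args, call.attrs)

x = relay.var('x', shape=())
y = relay.var('y', shape=())
expr = add(add(x, y), x)

counter = CallCounter()
counter.visit(expr)
assert counter.count == 2  # two distinct Call nodes, each visited once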
class TupleWrapper(object):
    """TupleWrapper.
...
@@ -25,113 +25,7 @@ import json
import attr
from . import ir_pass
from .op import Op
-from .expr import Var, Function, Call, If, GlobalVar, Constant, Let, Tuple
+from .expr import Function, GlobalVar, ExprMutator
-from ..build_module import build as tvm_build_module
-from .. contrib import graph_runtime
-from .ir_pass import infer_type
-from .. import cpu
class AbstractExprVisitor(object):
"""A visitor over Expr in Python."""
def __init__(self):
self.memo_map = {}
# pylint: disable=no-else-return
def visit(self, expr):
"""Apply the visitor to an expression."""
found = self.memo_map.get(expr)
if found:
return found
if isinstance(expr, Function):
res = self.visit_function(expr)
elif isinstance(expr, Call):
res = self.visit_call(expr)
elif isinstance(expr, Let):
res = self.visit_let(expr)
elif isinstance(expr, Var):
res = self.visit_var(expr)
elif isinstance(expr, GlobalVar):
res = self.visit_global_var(expr)
elif isinstance(expr, If):
res = self.visit_if(expr)
elif isinstance(expr, Tuple):
res = self.visit_tuple(expr)
elif isinstance(expr, Constant):
res = self.visit_constant(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
self.memo_map[expr] = res
return res
def visit_function(self, _):
raise Exception("Abstract method please implement me.")
def visit_let(self, _):
raise Exception("Abstract method please implement me.")
def visit_call(self, _):
raise Exception("Abstract method please implement me.")
def visit_var(self, _):
raise Exception("Abstract method please implement me.")
def visit_type(self, typ):
return typ
def visit_if(self, _):
raise Exception("Abstract method please implement me.")
def visit_tuple(self, _):
raise Exception("Abstract method please implement me.")
def visit_constant(self, _):
raise Exception("Abstract method please implement me.")
def visit_global_var(self, _):
raise Exception("Abstract method please implement me.")
class ExprMutator(AbstractExprVisitor):
"""A functional visitor over Expr in Python."""
def visit_function(self, fn):
new_body = self.visit(fn.body)
return Function(
list(fn.params),
fn.ret_type, new_body,
fn.type_params)
def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
return Let(new_var, new_val, new_body)
def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
return Call(new_fn, new_args, call.attrs)
def visit_var(self, var):
return var
def visit_global_id(self, global_var):
return global_var
def visit_if(self, ite):
return If(
self.visit(ite.guard),
self.visit(ite.true_b),
self.visit(ite.false_b))
def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields])
def visit_constant(self, const):
return const
@attr.s
@@ -359,8 +253,8 @@ class GraphRuntimeCodegen(ExprMutator):
        self.add_binding(ident, val_ref)
        return self.visit(body)

-    def visit_var(self, var):
-        return self.lookup(var)
+    def visit_var(self, rvar):
+        return self.lookup(rvar)

    def visit_call(self, call):
        """Transform a ::tvm.relay.Call into an operator in the TVM graph."""
@@ -472,80 +366,3 @@ class GraphRuntimeCodegen(ExprMutator):
        }
        return json.dumps(json_dict)
def build(env, func, target=None):
"""
Compile a single function to the components needed by the
TVM RTS.
Parameters
----------
func: relay.Expr
The function to build.
target: optional str
The target platform.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
"""
if target is None:
target = 'llvm'
comp = GraphRuntimeCodegen(env)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops = ir_pass.lower_ops(env, func)
mod = tvm_build_module([lf.lowered_func for lf in lowered_ops], target)
# Therefore the call to compile must come after.
comp.codegen(func)
graph_json = comp.to_json()
return graph_json, mod, None # params currently isn't supported by API
def graph_evaluate(env, func, *args):
"""
Corresponding function to tvm.relay.eval.evaluate.
This function evaluates a Relay expression on the
TVM graph_runtime.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
expr: tvm.relay.Expr
The expression to evaluate.
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
Returns
-------
value: tvm.NDArray
The output Tensor produced by evaluating the expression.
"""
func = infer_type(func, env)
func = ir_pass.fuse_ops(env, func)
func = infer_type(func, env)
graph_json, mod, params = build(env, func)
assert params is None
gmodule = graph_runtime.create(graph_json, mod, cpu(0))
# Create map of inputs.
inputs = {}
for i, arg in enumerate(args):
inputs[func.params[i].name_hint] = arg
# Set the inputs here.
gmodule.set_input(**inputs)
# Run the module, and fetch the output.
gmodule.run()
return gmodule.get_output(0)
@@ -4,12 +4,16 @@ from __future__ import absolute_import
import numpy as np
from .. import register_func, nd
from .base import NodeBase, register_relay_node
+from . import build_module
from . import _make
from . import _interpreter
from . import ir_pass
-from .expr import Call, Constant, GlobalVar
-from . import const
+from .env import Environment
+from .expr import Call, Constant, GlobalVar, Function, const
+from .scope_builder import ScopeBuilder
from .._ffi.base import integer_types
+from ..contrib import graph_runtime as tvm_runtime
+from .. import cpu

class Value(NodeBase):
    """Base class of all values.
@@ -83,48 +87,122 @@ def _arg_to_ast(arg):
    else:
        return const(arg)
class Executor(object):
"""An abstract interface for executing Relay programs."""
def __init__(self, env=None):
"""
Parameters
----------
env: relay.Environment, optional
    The global environment to execute in; an empty
    environment is created if omitted.
"""
if env is None:
self.env = Environment({})
else:
self.env = env
-def apply_passes(expr, env=None):
-    ck_expr = ir_pass.infer_type(expr, env=env)
-    fused_expr = ir_pass.fuse_ops(env, ck_expr)
-    return fused_expr
+    def optimize(self, expr):
+        # TODO: We need to move this optimization code into the optimizer/pass manager
+        ck_expr = ir_pass.infer_type(expr, env=self.env)
+        fused_expr = ir_pass.fuse_ops(self.env, ck_expr)
+        ck_fused = ir_pass.infer_type(fused_expr, env=self.env)
+        return ck_fused

-def evaluate(env, expr, *args):
-    """
-    Evaluate a Relay expression on the interpreter.
-
-    Parameters
-    ----------
-    env: tvm.relay.Environment
-        The global environment used.
-    expr: tvm.relay.Expr
-        The expression to evaluate.
-    args: list of tvm.relay.Expr
-        The arguments to apply to the expression, only works
-        if the expression has a function type.
-
-    Returns
-    -------
-    value: tvm.relay.eval.Value
-        The value produced by evaluating the expression.
-    """
-    # assert len(args) == 0
-    relay_args = []
-    for arg in args:
-        relay_args.append(_arg_to_ast(arg))
-    # TODO: We need to move this optimization code into the optimizer/pass manager
-    if isinstance(expr, GlobalVar):
-        func = env[expr]
-        func = apply_passes(func, env)
-        env._add(expr, func, True)
-        opt_expr = Call(expr, relay_args)
-        # import pdb; pdb.set_trace()
-        return _interpreter.evaluate(env, opt_expr)
-    else:
-        expr = Call(expr, relay_args)
-        opt_expr = apply_passes(expr, env)
-        return _interpreter.evaluate(env, opt_expr)
+    def _make_executor(self, expr):
+        """
+        Construct a Python function that implements the evaluation
+        of the expression.
+
+        Parameters
+        ----------
+        expr: relay.Expr
+            The Relay expression to execute.
+
+        Returns
+        -------
+        executor: function
+            A Python function which implements the behavior of `expr`.
+        """
+        raise Exception("abstract method: please implement me.")
+
+    def evaluate(self, expr, params=None):
+        """
+        Evaluate a Relay expression on the executor.
+
+        Parameters
+        ----------
+        expr: tvm.relay.Expr
+            The expression to evaluate.
+        params: sequence of (var, value) pairs, optional
+            Bindings wrapped around `expr` as let-bindings
+            before evaluation.
+        """
+        if params:
+            scope_builder = ScopeBuilder()
+            for key, value in params:
+                scope_builder.let(key, value)
+            scope_builder.ret(expr)
+            expr = scope_builder.get()
+
+        if isinstance(expr, Function):
+            assert not ir_pass.free_vars(expr)
+
+        return self._make_executor(expr)
+
+
+class Interpreter(Executor):
+    """A wrapper around the Relay interpreter; implements the Executor interface."""
+    def __init__(self, env=None):
+        Executor.__init__(self, env)
+
+    def _make_executor(self, expr):
+        def _interp_wrapper(*args):
+            relay_args = []
+            for arg in args:
+                relay_args.append(_arg_to_ast(arg))
+
+            if isinstance(expr, GlobalVar):
+                func = self.env[expr]
+                func = self.optimize(func)
+                self.env._add(expr, func, True)
+                opt_expr = Call(expr, relay_args)
+                return _interpreter.evaluate(self.env, opt_expr)
+            else:
+                call = Call(expr, relay_args)
+                opt_expr = self.optimize(call)
+                return _interpreter.evaluate(self.env, opt_expr)
+        return _interp_wrapper
+
+
+class GraphRuntime(Executor):
+    """A wrapper around the TVM graph runtime; implements the Executor interface."""
+    def __init__(self, env=None):
+        Executor.__init__(self, env)
+
+    def _make_executor(self, expr):
+        def _graph_wrapper(*args):
+            func = self.optimize(expr)
+            graph_json, mod, params = build_module.build(func, env=self.env)
+            assert params is None
+            gmodule = tvm_runtime.create(graph_json, mod, cpu(0))
+            # Create map of inputs.
+            inputs = {}
+            for i, arg in enumerate(args):
+                inputs[func.params[i].name_hint] = arg
+            # Set the inputs here.
+            gmodule.set_input(**inputs)
+            # Run the module, and fetch the output.
+            gmodule.run()
+            return gmodule.get_output(0)
+        return _graph_wrapper
+
+
+def create_executor(mode='debug', env=None):
+    if mode == 'debug':
+        return Interpreter(env)
+    elif mode == 'graph':
+        return GraphRuntime(env)
+    else:
+        raise Exception("unknown mode {0}".format(mode))
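And a minimal usage sketch of the new interface (the scalar-add function mirrors the tests below; illustrative only):

import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.op import add

x = relay.var('x', shape=())
y = relay.var('y', shape=())
func = relay.Function([x, y], add(x, y))

# 'debug' selects the Interpreter, 'graph' the GraphRuntime wrapper;
# evaluate() returns a Python closure that runs the expression.
intrp = create_executor('debug')
result = intrp.evaluate(func)(np.array(1.0, dtype='float32'),
                              np.array(2.0, dtype='float32'))
print(result.asnumpy())  # 2.0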
import numpy as np
from tvm import relay
+from tvm.relay import create_executor
from tvm.relay.ir_pass import infer_type
-from tvm.relay.interpreter import evaluate
+from tvm.relay.interpreter import Interpreter
-from tvm.relay.graph_runtime_codegen import graph_evaluate
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay.op import add
from tvm.relay.env import Environment

# @tq, @jr should we put this in testing ns?
-def check_rts(env, expr, args, expected_result):
+def check_rts(expr, args, expected_result, env=None):
    """
    Check that evaluating `expr` applied to the arguments produces
    `result` on both the evaluator and TVM runtime.
@@ -25,8 +25,10 @@ def check_rts(env, expr, args, expected_result):
    expected_result:
        The expected result of running the expression.
    """
-    eval_result = evaluate(env, expr, *args)
-    rts_result = graph_evaluate(env, expr, *args)
+    intrp = create_executor('debug', env=env)
+    graph = create_executor('graph', env=env)
+    eval_result = intrp.evaluate(expr)(*args)
+    rts_result = graph.evaluate(expr)(*args)
    np.testing.assert_allclose(eval_result.asnumpy(), rts_result.asnumpy())
def test_add_op_scalar():
@@ -36,13 +38,12 @@ def test_add_op_scalar():
        return x + y;
    }
    """
-    env = Environment()
    x = relay.var('x', shape=())
    y = relay.var('y', shape=())
    func = relay.Function([x, y], add(x, y))
    x_data = np.array(10.0, dtype='float32')
    y_data = np.array(1.0, dtype='float32')
-    check_rts(env, func, [x_data, y_data], x_data + y_data)
+    check_rts(func, [x_data, y_data], x_data + y_data)

def test_add_op_tensor():
    """
@@ -51,13 +52,12 @@ def test_add_op_tensor():
        return x + y;
    }
    """
-    env = Environment()
    x = relay.var('x', shape=(10, 5))
    y = relay.var('y', shape=(10, 5))
    func = relay.Function([x, y], add(x, y))
    x_data = np.random.rand(10, 5).astype('float32')
    y_data = np.random.rand(10, 5).astype('float32')
-    check_rts(env, func, [x_data, y_data], x_data + y_data)
+    check_rts(func, [x_data, y_data], x_data + y_data)

def test_add_op_broadcast():
    """
@@ -66,13 +66,12 @@ def test_add_op_broadcast():
        return x + y;
    }
    """
-    env = Environment()
    x = relay.var('x', shape=(10, 5))
    y = relay.var('y', shape=(1, 5))
    func = relay.Function([x, y], add(x, y))
    x_data = np.random.rand(10, 5).astype('float32')
    y_data = np.random.rand(1, 5).astype('float32')
-    check_rts(env, func, [x_data, y_data], x_data + y_data)
+    check_rts(func, [x_data, y_data], x_data + y_data)

if __name__ == "__main__":
    test_add_op_scalar()
...
import numpy as np
import tvm
from tvm import relay
-from tvm.relay.interpreter import Value, TupleValue, evaluate
+from tvm.relay.interpreter import Value, TupleValue
from tvm.relay import op
from tvm.relay.scope_builder import ScopeBuilder
-from tvm.relay import testing
+from tvm.relay import testing, create_executor

def check_eval(expr, args, expected_result, env=None, rtol=1e-07):
-    if env is None:
-        env = relay.env.Environment({})
-    result = evaluate(env, expr, *args)
+    intrp = create_executor(env=env)
+    result = intrp.evaluate(expr)(*args)
    np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
@@ -32,8 +30,6 @@ def test_tuple_value():
def test_id():
    x = relay.var('x', 'float32')
    ident = relay.Function([x], x)
-    env = relay.env.Environment({})
-    res = evaluate(env, ident, 1.0)
    check_eval(ident, [1.0], 1.0)
...