Commit 94cc89da by Jared Roesch Committed by Haichen Shen

Add support for passing arguments by args and kwargs when using executor (#2402)

* Add support for passing arguments by args and kwargs when using executor

* Fix linting

* Update comment, and add arity checking

* Small tweak to error message
parent 794bf7fe
...@@ -73,6 +73,63 @@ def _arg_to_ast(arg): ...@@ -73,6 +73,63 @@ def _arg_to_ast(arg):
class Executor(object): class Executor(object):
"""An abstract interface for executing Relay programs.""" """An abstract interface for executing Relay programs."""
def _convert_args(self, expr, args, kwargs):
"""
Convert the combination of arguments and keyword arguments
into a sequence of arguments that may be passed to
a Relay evaluator.
We first provide all positional arguments, and then attempt
to fill in the remaining arguments using the keyword arguments. We
map the keyword arguments to the corresponding parameters, if there
is an ambiguity between positional and keyword arguments this
procedure will raise an error.
Parameters
----------
expr: relay.Expr
The expression to evaluate
args: List[tvm.NDArray]
The arguments to pass to the evaluator.
kwargs: Dict[str, tvm.NDArrray]
The keyword arguments to pass to the evaluator.
Returns:
args: List[tvm.NDArray]
The new arguments with all keyword arguments placed in the correct slot.
"""
if not kwargs:
return args
if kwargs and not isinstance(expr, Function):
raise Exception("can only supply keyword parameters for a \
relay.Function, found {0}".format(expr))
params = expr.params
param_names = [p.name_hint for p in params]
num_of_args = len(args)
cargs = list(args)[:]
for i, name in enumerate(param_names):
if i < num_of_args:
if kwargs.get(name):
raise Exception(
"duplicate argument supplied in \
both positional args (at position: {0}), \
and keyword argument (with name: {1})".format(i, name))
else:
cargs.append(kwargs[name])
if len(cargs) != len(params):
raise Exception(
"insufficient arguments, expected" \
" {0}, provided {1}".format(len(cargs), len(params)))
return tuple(cargs)
def _make_executor(self, _): def _make_executor(self, _):
""" """
Construct a Python function that implements the evaluation Construct a Python function that implements the evaluation
...@@ -166,7 +223,9 @@ class Interpreter(Executor): ...@@ -166,7 +223,9 @@ class Interpreter(Executor):
return ck_fused return ck_fused
def _make_executor(self, expr): def _make_executor(self, expr):
def _interp_wrapper(*args): def _interp_wrapper(*args, **kwargs):
args = self._convert_args(expr, args, kwargs)
relay_args = [] relay_args = []
for arg in args: for arg in args:
relay_args.append(_arg_to_ast(arg)) relay_args.append(_arg_to_ast(arg))
......
...@@ -269,7 +269,9 @@ class GraphExecutor(_interpreter.Executor): ...@@ -269,7 +269,9 @@ class GraphExecutor(_interpreter.Executor):
gmodule = _graph_rt.create(graph_json, mod, self.ctx) gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params: if params:
gmodule.set_input(*params) gmodule.set_input(*params)
def _graph_wrapper(*args):
def _graph_wrapper(*args, **kwargs):
args = self._convert_args(func, args, kwargs)
# Create map of inputs. # Create map of inputs.
for i, arg in enumerate(args): for i, arg in enumerate(args):
gmodule.set_input(i, arg) gmodule.set_input(i, arg)
......
...@@ -118,6 +118,18 @@ def test_binds(): ...@@ -118,6 +118,18 @@ def test_binds():
res = intrp.evaluate(y, binds={x: xx}).asnumpy() res = intrp.evaluate(y, binds={x: xx}).asnumpy()
tvm.testing.assert_allclose(xx + xx, res) tvm.testing.assert_allclose(xx + xx, res)
def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
z = relay.var("z", shape=(1, 10))
f = relay.Function([x, y, z], x + y + z)
x_data = np.random.rand(1, 10).astype('float32')
y_data = np.random.rand(1, 10).astype('float32')
z_data = np.random.rand(1, 10).astype('float32')
params = { 'y': y_data, 'z': z_data }
intrp = create_executor("debug")
res = intrp.evaluate(f)(x_data, **params).data
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
...@@ -127,3 +139,4 @@ if __name__ == "__main__": ...@@ -127,3 +139,4 @@ if __name__ == "__main__":
test_simple_loop() test_simple_loop()
test_loop() test_loop()
test_binds() test_binds()
test_kwargs_params()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment