Unverified Commit 9cc78741 by Jared Roesch Committed by GitHub

[Relay][Params] Add APIs for storing and retrieving parameters from individual functions. (#4194)

* Add support for attaching params

* Fix types

* Fix test
parent 93d610a1
......@@ -274,6 +274,19 @@ class FunctionNode : public ExprNode {
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());
/*!
* \brief Attach the function's parameters to its attributes for use in analysis.
* \return The function with its parameters attached.
*/
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;
/*!
* \brief Retrieve the function's parameters.
*
* \return The function's parameter.
*/
tvm::Map<Var, Constant> GetParams() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};
......@@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);
/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
......
......@@ -27,6 +27,7 @@ from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray
# will be registered afterwards
_op_make = None
......@@ -305,6 +306,17 @@ class Function(Expr):
"""
return Call(self, args, None, None)
def get_params(self):
return _expr.FunctionGetParams(self)
def set_params(self, params):
for key in params:
value = params[key]
if isinstance(value, NDArray):
params[key] = Constant(value)
return _expr.FunctionSetParams(self, params)
@register_relay_node
class Call(Expr):
......
......@@ -159,6 +159,26 @@ bool FunctionNode::IsPrimitive() const {
return pval && pval->value != 0;
}
Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
}
TVM_REGISTER_API("relay._expr.FunctionSetParams")
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters);
});
tvm::Map<Var, Constant> FunctionNode::GetParams() const {
auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
return Downcast<tvm::Map<Var, Constant>>(node_ref);
}
TVM_REGISTER_API("relay._expr.FunctionGetParams")
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
return func->GetParams();
});
NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); }
......
......@@ -20,7 +20,7 @@ from tvm import relay
from tvm.expr import *
from tvm.relay import op
from tvm.relay.analysis import graph_equal
import numpy as np
def check_json_roundtrip(node):
json_str = tvm.save_json(node)
......@@ -160,7 +160,6 @@ def test_global_var():
str(gv)
check_json_roundtrip(gv)
def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
......@@ -175,6 +174,34 @@ def test_function():
str(fn)
check_json_roundtrip(fn)
def test_function_attrs():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
ret_type = relay.TupleType(tvm.convert([]))
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, body, ret_type, type_params)
model_params = {}
for param in params[:1]:
cty = param.type_annotation
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
model_params[param] = tvm.nd.array(tensor)
fn = fn.set_params(model_params)
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
str(fn)
check_json_roundtrip(fn)
json_str = tvm.save_json(fn)
fn_after = tvm.load_json(json_str)
model_params_after = fn_after.get_params()
after_keys = [item[0] for item in model_params_after.items()]
for key1, key2 in zip(model_params, after_keys):
assert key1.name_hint == key2.name_hint
p1 = model_params[key1]
p2 = model_params_after[key2]
np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())
def test_call():
op = relay.Var('f')
......@@ -257,9 +284,11 @@ if __name__ == "__main__":
test_local_var()
test_global_var()
test_function()
test_function_attrs()
test_call()
test_let()
test_if()
test_tuple_get_item()
test_op()
test_conv2d_attrs()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment