Unverified Commit 41835d17 by Jon Soifer Committed by GitHub

[Relay] Expose FunctionGetAttr to Python (#4905)

* [Relay] Expose FunctionGetAttr to Python

* add test

Co-authored-by: Jon Soifer <jonso@microsoft.com>
parent 9d646543
......@@ -280,6 +280,9 @@ class Function(BaseFunc):
def set_attribute(self, name, ref):
return _expr.FunctionSetAttr(self, name, ref)
def get_attribute(self, name):
return _expr.FunctionGetAttr(self, name)
@register_relay_node
class Call(ExprWithOp):
......
......@@ -360,6 +360,12 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
return FunctionSetAttr(func, name, ref);
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr")
.set_body_typed(
[](Function func, std::string name) {
return FunctionGetAttr(func, name);
});
TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); });
......
......@@ -168,10 +168,12 @@ def test_function():
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, body, ret_type, type_params)
fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value"))
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
assert fn.get_attribute("test_attribute") == "value"
str(fn)
check_json_roundtrip(fn)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment