Commit db841c24 by Steven S. Lyubomirsky Committed by Jared Roesch

[Relay][Testing] Relay-to-Python compilation (#3156)

* First pass at Relay-to-Python converter testing utility

* Indicate astor as a dependency

* Add astor dep to host as well

* Typos and small bugs

* Handle ADTs and matching in Python conversion

* Remove any dependency on ast.parse

* Eliminate unnecessary type var field in Python version of ConstructorValue (already gone on C++ side)

* Update constructor value, fix syntax errors

* Don't forget keywords arg on Call nodes

* Fix some incorrect calls to ast nodes

* Fix more calls, a little more cleaning up

* Missing cases in attr conversion

* Lower op calls instead of running them through interpreter, as in @MarisaKirisame's AoT compiler

* We do still need the module

* Remove changes to op attrs: Will PR separately

* Smoke test and corrections

* More tests and fixes

* Ensure imports are properly global in generated Python code

* Add unit tests for refs

* Add unit test for tuple indexing

* Add unit test for if expression

* Remove astor dependency

* Remove astor from meta.yaml too

* Fix if test and add basic local function test

* Add global function test, refactor earlier tests

* Correct 'clause' field in ADT so Python and C++ field names match

* More fixes and tests for matching and constructors

* Dramatically simplify matching: no need for a thunk

* Improve ref writing test

* Ensure local recursion works

* cleanup

* Add test for global recursion

* Add test for higher-order calls

* Get ops working, add basic tests

* Remove accidentally duplicated test

* More docstrings to appease pylint

* Forgot to fix a test using constructor values

* Reduce optimization level in fusion and fix tuple input to operators

* Test op with tuple output, fix tuple output code

* Add unit test for batch norm

* Add a couple more tricky test cases

* Correct nat constructor to drop unnecessary field

* Fix the op attrs file (accidentally reduced it)

* Address review comments

* Adapt to new ConstructorValue representation (no more runtime dep on module)

* Use pass manager and updated interfaces. Extend module.from_expr to accommodate necessary demands

* Use sequential return value

* Lift out nested conditionals

* Replace triple single quotes with triple double quotes

* Use main variable instead of entry_func
parent 93d1c06d
...@@ -243,7 +243,7 @@ class MatchNode : public ExprNode { ...@@ -243,7 +243,7 @@ class MatchNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data); v->Visit("data", &data);
v->Visit("clause", &clauses); v->Visit("clauses", &clauses);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
......
...@@ -180,17 +180,19 @@ class ModuleNode : public RelayNode { ...@@ -180,17 +180,19 @@ class ModuleNode : public RelayNode {
/*! \brief Construct a module from a standalone expression. /*! \brief Construct a module from a standalone expression.
* *
* Allows one to optionally pass a global function map as * Allows one to optionally pass a global function map and
* well. * map of type definitions as well.
* *
* \param expr The expression to set as the main function to the module. * \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map. * \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
* *
* \returns A module with expr set as the main function. * \returns A module with expr set as the main function.
*/ */
TVM_DLL static Module FromExpr( TVM_DLL static Module FromExpr(
const Expr& expr, const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {}); const tvm::Map<GlobalVar, Function>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module"; static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
......
...@@ -74,9 +74,9 @@ class Closure(Value): ...@@ -74,9 +74,9 @@ class Closure(Value):
@register_relay_node @register_relay_node
class ConstructorValue(Value): class ConstructorValue(Value):
def __init__(self, tag, fields, constructor, types): def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor, types) _make.ConstructorValue, tag, fields, constructor)
@register_relay_node @register_relay_node
......
...@@ -183,7 +183,7 @@ class ExprVisitor(ExprFunctor): ...@@ -183,7 +183,7 @@ class ExprVisitor(ExprFunctor):
def visit_match(self, m): def visit_match(self, m):
self.visit(m.data) self.visit(m.data)
for c in m.clause: for c in m.clauses:
self.visit(c.rhs) self.visit(c.rhs)
......
...@@ -179,5 +179,26 @@ class Module(RelayNode): ...@@ -179,5 +179,26 @@ class Module(RelayNode):
return _module.Module_LookupTag(self, tag) return _module.Module_LookupTag(self, tag)
@staticmethod @staticmethod
def from_expr(expr): def from_expr(expr, functions=None, type_defs=None):
return _module.Module_FromExpr(expr) """Construct a module from a standalone expression.
Parameters
----------
expr: Expr
The starting expression
global_funcs: Optional[dict]
Map of global vars to function definitions
type_defs: Optional[dict]
Map of global type vars to type definitions
Returns
-------
mod: Module
A module containing the passed definitions,
where expr is set as the entry point
(wrapped in a function if necessary)
"""
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs)
...@@ -35,6 +35,7 @@ from . import yolo_detection ...@@ -35,6 +35,7 @@ from . import yolo_detection
from .config import ctx_list from .config import ctx_list
from .init import create_workload from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
from .py_converter import to_python, run_as_python
def run_opt_pass(expr, opt_pass): def run_opt_pass(expr, opt_pass):
......
...@@ -168,8 +168,8 @@ def make_nat_value(prelude, n): ...@@ -168,8 +168,8 @@ def make_nat_value(prelude, n):
constructs a ConstructorValue representing that value as a nat. constructs a ConstructorValue representing that value as a nat.
""" """
if n == 0: if n == 0:
return ConstructorValue(prelude.z.tag, [], None, []) return ConstructorValue(prelude.z.tag, [], None)
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, []) return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None)
def make_nat_expr(prelude, n): def make_nat_expr(prelude, n):
......
...@@ -187,8 +187,9 @@ void ModuleNode::Update(const Module& mod) { ...@@ -187,8 +187,9 @@ void ModuleNode::Update(const Module& mod) {
Module ModuleNode::FromExpr( Module ModuleNode::FromExpr(
const Expr& expr, const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) { const tvm::Map<GlobalVar, Function>& global_funcs,
auto mod = ModuleNode::make(global_funcs, {}); const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = ModuleNode::make(global_funcs, type_definitions);
auto func_node = expr.as<FunctionNode>(); auto func_node = expr.as<FunctionNode>();
Function func; Function func;
if (func_node) { if (func_node) {
...@@ -266,9 +267,14 @@ TVM_REGISTER_API("relay._module.Module_LookupTag") ...@@ -266,9 +267,14 @@ TVM_REGISTER_API("relay._module.Module_LookupTag")
}); });
TVM_REGISTER_API("relay._module.Module_FromExpr") TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<Module(Expr)>([](Expr e) { .set_body_typed<
return ModuleNode::FromExpr(e); Module(Expr,
}); tvm::Map<GlobalVar, Function>,
tvm::Map<GlobalTypeVar, TypeData>)>([](Expr e,
tvm::Map<GlobalVar, Function> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
});
TVM_REGISTER_API("relay._module.Module_Update") TVM_REGISTER_API("relay._module.Module_Update")
.set_body_typed<void(Module, Module)>([](Module mod, Module from) { .set_body_typed<void(Module, Module)>([](Module mod, Module from) {
......
...@@ -75,9 +75,9 @@ iterate = p.iterate ...@@ -75,9 +75,9 @@ iterate = p.iterate
# this is an example of creating the adt value in python side # this is an example of creating the adt value in python side
def make_nat(n): def make_nat(n):
if n != 0: if n != 0:
return ConstructorValue(s, [make_nat(n - 1)], []) return ConstructorValue(s, [make_nat(n - 1)])
else: else:
return ConstructorValue(z, [], []) return ConstructorValue(z, [])
def make_nat_expr(n): def make_nat_expr(n):
assert n >= 0 assert n >= 0
......
...@@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple(): ...@@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple():
prelude = relay.prelude.Prelude(mod) prelude = relay.prelude.Prelude(mod)
intrp = create_executor("debug", mod) intrp = create_executor("debug", mod)
nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, []) nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
cons_value = ConstructorValue(prelude.cons.tag, [ cons_value = ConstructorValue(prelude.cons.tag, [
TensorValue(np.random.rand(1, 10).astype('float32')), TensorValue(np.random.rand(1, 10).astype('float32')),
nil_value nil_value
], prelude.cons, [relay.TensorType((1, 10), 'float32')]) ], prelude.cons)
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32'))) ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[ tuple_value = TupleValue(*[
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment