Commit db841c24 by Steven S. Lyubomirsky Committed by Jared Roesch

[Relay][Testing] Relay-to-Python compilation (#3156)

* First pass at Relay-to-Python converter testing utility

* Indicate astor as a dependency

* Add astor dep to host as well

* Typos and small bugs

* Handle ADTs and matching in Python conversion

* Remove any dependency on ast.parse

* Eliminate unnecessary type var field in Python version of ConstructorValue (already gone on C++ side)

* Update constructor value, fix syntax errors

* Don't forget keywords arg on Call nodes

* Fix some incorrect calls to ast nodes

* Fix more calls, a little more cleaning up

* Missing cases in attr conversion

* Lower op calls instead of running them through interpreter, as in @MarisaKirisame's AoT compiler

* We do still need the module

* Remove changes to op attrs: Will PR separately

* Smoke test and corrections

* More tests and fixes

* Ensure imports are properly global in generated Python code

* Add unit tests for refs

* Add unit test for tuple indexing

* Add unit test for if expression

* Remove astor dependency

* Remove astor from meta.yaml too

* Fix if test and add basic local function test

* Add global function test, refactor earlier tests

* Correct 'clause' field in ADT so Python and C++ field names match

* More fixes and tests for matching and constructors

* Dramatically simplify matching: no need for a thunk

* Improve ref writing test

* Ensure local recursion works

* cleanup

* Add test for global recursion

* Add test for higher-order calls

* Get ops working, add basic tests

* Remove accidentally duplicated test

* More docstrings to appease pylint

* Forgot to fix a test using constructor values

* Reduce optimization level in fusion and fix tuple input to operators

* Test op with tuple output, fix tuple output code

* Add unit test for batch norm

* Add a couple more tricky test cases

* Correct nat constructor to drop unnecessary field

* Fix the op attrs file (accidentally reduced it)

* Address review comments

* Adapt to new ConstructorValue representation (no more runtime dep on module)

* Use pass manager and updated interfaces. Extend module.from_expr to accommodate necessary demands

* Use sequential return value

* Lift out nested conditionals

* Replace triple single quotes with triple double quotes

* Use main variable instead of entry_func
parent 93d1c06d
......@@ -243,7 +243,7 @@ class MatchNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("clause", &clauses);
v->Visit("clauses", &clauses);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
......
......@@ -180,17 +180,19 @@ class ModuleNode : public RelayNode {
/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map as
* well.
* Allows one to optionally pass a global function map and
* map of type definitions as well.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
*
* \returns A module with expr set as the main function.
*/
TVM_DLL static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});
const tvm::Map<GlobalVar, Function>& global_funcs = {},
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions = {});
static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
......
......@@ -74,9 +74,9 @@ class Closure(Value):
@register_relay_node
class ConstructorValue(Value):
def __init__(self, tag, fields, constructor, types):
def __init__(self, tag, fields, constructor):
self.__init_handle_by_constructor__(
_make.ConstructorValue, tag, fields, constructor, types)
_make.ConstructorValue, tag, fields, constructor)
@register_relay_node
......
......@@ -183,7 +183,7 @@ class ExprVisitor(ExprFunctor):
def visit_match(self, m):
self.visit(m.data)
for c in m.clause:
for c in m.clauses:
self.visit(c.rhs)
......
......@@ -179,5 +179,26 @@ class Module(RelayNode):
return _module.Module_LookupTag(self, tag)
@staticmethod
def from_expr(expr):
return _module.Module_FromExpr(expr)
def from_expr(expr, functions=None, type_defs=None):
"""Construct a module from a standalone expression.
Parameters
----------
expr: Expr
The starting expression
global_funcs: Optional[dict]
Map of global vars to function definitions
type_defs: Optional[dict]
Map of global type vars to type definitions
Returns
-------
mod: Module
A module containing the passed definitions,
where expr is set as the entry point
(wrapped in a function if necessary)
"""
funcs = functions if functions is not None else {}
defs = type_defs if type_defs is not None else {}
return _module.Module_FromExpr(expr, funcs, defs)
......@@ -35,6 +35,7 @@ from . import yolo_detection
from .config import ctx_list
from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
from .py_converter import to_python, run_as_python
def run_opt_pass(expr, opt_pass):
......
......@@ -168,8 +168,8 @@ def make_nat_value(prelude, n):
constructs a ConstructorValue representing that value as a nat.
"""
if n == 0:
return ConstructorValue(prelude.z.tag, [], None, [])
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, [])
return ConstructorValue(prelude.z.tag, [], None)
return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None)
def make_nat_expr(prelude, n):
......
......@@ -187,8 +187,9 @@ void ModuleNode::Update(const Module& mod) {
Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) {
auto mod = ModuleNode::make(global_funcs, {});
const tvm::Map<GlobalVar, Function>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = ModuleNode::make(global_funcs, type_definitions);
auto func_node = expr.as<FunctionNode>();
Function func;
if (func_node) {
......@@ -266,9 +267,14 @@ TVM_REGISTER_API("relay._module.Module_LookupTag")
});
TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<Module(Expr)>([](Expr e) {
return ModuleNode::FromExpr(e);
});
.set_body_typed<
Module(Expr,
tvm::Map<GlobalVar, Function>,
tvm::Map<GlobalTypeVar, TypeData>)>([](Expr e,
tvm::Map<GlobalVar, Function> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
});
TVM_REGISTER_API("relay._module.Module_Update")
.set_body_typed<void(Module, Module)>([](Module mod, Module from) {
......
......@@ -75,9 +75,9 @@ iterate = p.iterate
# this is an example of creating the adt value in python side
def make_nat(n):
if n != 0:
return ConstructorValue(s, [make_nat(n - 1)], [])
return ConstructorValue(s, [make_nat(n - 1)])
else:
return ConstructorValue(z, [], [])
return ConstructorValue(z, [])
def make_nat_expr(n):
assert n >= 0
......
......@@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple():
prelude = relay.prelude.Prelude(mod)
intrp = create_executor("debug", mod)
nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, [])
nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil)
cons_value = ConstructorValue(prelude.cons.tag, [
TensorValue(np.random.rand(1, 10).astype('float32')),
nil_value
], prelude.cons, [relay.TensorType((1, 10), 'float32')])
], prelude.cons)
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment