Commit 2dac17d8 by Logan Weber Committed by Jared Roesch

[Relay] Move prelude to text format (#3939)

* Fix parser

* Doc fix

* Add module utility functions necessary for prelude

* Implement prelude in text format

* Remove programmatically constructed prelude defs

* Fix 0-arity type conses in pretty printer and test

* Make prelude loading backwards-compatible

* Fix patterns

* Improve some prelude defs

* Fix `ImportFromStd`

It needs to also follow the "add unchecked, add checked" pattern

* Lint roller

* Woops

* Address feedback

* Fix `test_list_constructor` VM test

* Fix `test_adt.py` failures
parent 9b46ace1
...@@ -92,7 +92,7 @@ class Var; ...@@ -92,7 +92,7 @@ class Var;
/*! /*!
* \brief A variable node in the IR. * \brief A variable node in the IR.
* *
* A vraible is uniquely identified by its address. * A variable is uniquely identified by its address.
* *
* Each variable is only binded once in the following nodes: * Each variable is only binded once in the following nodes:
* - Allocate * - Allocate
......
...@@ -117,7 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -117,7 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) { virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key()); LOG(FATAL) << "Do not have a default for " << op->type_key();
throw;
} }
private: private:
......
...@@ -88,20 +88,33 @@ class ModuleNode : public RelayNode { ...@@ -88,20 +88,33 @@ class ModuleNode : public RelayNode {
TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false); TVM_DLL void Add(const GlobalVar& var, const Function& func, bool update = false);
/*! /*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
*
* It does not do type inference as Add does.
*/
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func);
/*!
* \brief Add a type-level definition to the global environment. * \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition. * \param var The var of the global type definition.
* \param type The type definition. * \param type The ADT.
* \param update Controls whether you can replace a definition in the
* environment.
*/ */
TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type); TVM_DLL void AddDef(const GlobalTypeVar& var, const TypeData& type, bool update = false);
/*! /*!
* \brief Add a function to the global environment. * \brief Add a type definition to the global environment.
* \param var The name of the global function. * \param var The name of the global function.
* \param func The function. * \param type The ADT.
* \param update Controls whether you can replace a definition in the
* environment.
* *
* It does not do type inference as Add does. * It does not do type inference as AddDef does.
*/ */
TVM_DLL void AddUnchecked(const GlobalVar& var, const Function& func); TVM_DLL void AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update = false);
/*! /*!
* \brief Update a function in the global environment. * \brief Update a function in the global environment.
...@@ -111,6 +124,13 @@ class ModuleNode : public RelayNode { ...@@ -111,6 +124,13 @@ class ModuleNode : public RelayNode {
TVM_DLL void Update(const GlobalVar& var, const Function& func); TVM_DLL void Update(const GlobalVar& var, const Function& func);
/*! /*!
* \brief Update a type definition in the global environment.
* \param var The name of the global type definition to update.
* \param type The new ADT.
*/
TVM_DLL void UpdateDef(const GlobalTypeVar& var, const TypeData& type);
/*!
* \brief Remove a function from the global environment. * \brief Remove a function from the global environment.
* \param var The name of the global function to update. * \param var The name of the global function to update.
*/ */
...@@ -131,6 +151,12 @@ class ModuleNode : public RelayNode { ...@@ -131,6 +151,12 @@ class ModuleNode : public RelayNode {
TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;
/*! /*!
* \brief Collect all global vars defined in this module.
* \returns An array of global vars
*/
tvm::Array<GlobalVar> GetGlobalVars() const;
/*!
* \brief Look up a global function by its name. * \brief Look up a global function by its name.
* \param str The unique string specifying the global variable. * \param str The unique string specifying the global variable.
* \returns The global variable. * \returns The global variable.
...@@ -138,6 +164,12 @@ class ModuleNode : public RelayNode { ...@@ -138,6 +164,12 @@ class ModuleNode : public RelayNode {
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;
/*! /*!
* \brief Collect all global type vars defined in this module.
* \returns An array of global type vars
*/
tvm::Array<GlobalTypeVar> GetGlobalTypeVars() const;
/*!
* \brief Look up a global function by its variable. * \brief Look up a global function by its variable.
* \param var The global var to lookup. * \param var The global var to lookup.
* \returns The function named by the variable argument. * \returns The function named by the variable argument.
......
...@@ -103,7 +103,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> { ...@@ -103,7 +103,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
virtual R VisitPattern_(const PatternTupleNode* op, virtual R VisitPattern_(const PatternTupleNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT; Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) { virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key()); LOG(FATAL) << "Do not have a default for " << op->type_key();
throw;
} }
private: private:
......
...@@ -87,33 +87,33 @@ callList ...@@ -87,33 +87,33 @@ callList
expr expr
// operators // operators
: '(' expr ')' # paren : '(' expr ')' # paren
// function application // function application
| expr '(' callList ')' # call | expr '(' callList ')' # call
| '-' expr # neg | '-' expr # neg
| expr op=('*'|'/') expr # binOp | expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp | expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp | expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp | expr op=('=='|'!=') expr # binOp
// function definition // function definition
| func # funcExpr | func # funcExpr
// tuples and tensors // tuples and tensors
| '(' ')' # tuple | '(' ')' # tuple
| '(' expr ',' ')' # tuple | '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple | '(' expr (',' expr)+ ')' # tuple
| '[' (expr (',' expr)*)? ']' # tensor | '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse | 'if' '(' expr ')' body 'else' body # ifElse
| matchType '(' expr ')' '{' matchClauseList? '}' # match | matchType expr '{' matchClauseList? '}' # match
| expr '.' NAT # projection | expr '.' NAT # projection
// sequencing // sequencing
| 'let' var '=' expr ';' expr # let | 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr // sugar for let %_ = expr; expr
| expr ';;' expr # let | expr ';;' expr # let
| graphVar '=' expr ';' expr # graph | graphVar '=' expr ';' expr # graph
| ident # identExpr | ident # identExpr
| scalar # scalarExpr | scalar # scalarExpr
| meta # metaExpr | meta # metaExpr
| QUOTED_STRING # stringExpr | QUOTED_STRING # stringExpr
; ;
func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ; func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ;
...@@ -128,14 +128,16 @@ constructorName: CNAME ; ...@@ -128,14 +128,16 @@ constructorName: CNAME ;
adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ; adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ;
adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ; adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ;
matchClauseList: matchClause (',' matchClause)* ','? ; matchClauseList: matchClause (',' matchClause)* ','? ;
matchClause: constructorName patternList? '=>' ('{' expr '}' | expr) ; matchClause: pattern '=>' ('{' expr '}' | expr) ;
// complete or incomplete match, respectively // complete or incomplete match, respectively
matchType : 'match' | 'match?' ; matchType : 'match' | 'match?' ;
patternList: '(' pattern (',' pattern)* ')'; patternList: '(' pattern (',' pattern)* ')';
pattern pattern
: '_' : '_' # wildcardPattern
| localVar (':' typeExpr)? | localVar (':' typeExpr)? # varPattern
| constructorName patternList? # constructorPattern
| patternList # tuplePattern
; ;
adtCons: constructorName adtConsParamList? ; adtCons: constructorName adtConsParamList? ;
...@@ -155,6 +157,7 @@ attr: CNAME '=' expr ; ...@@ -155,6 +157,7 @@ attr: CNAME '=' expr ;
typeExpr typeExpr
: '(' ')' # tupleType : '(' ')' # tupleType
| '(' typeExpr ')' # typeParen
| '(' typeExpr ',' ')' # tupleType | '(' typeExpr ',' ')' # tupleType
| '(' typeExpr (',' typeExpr)+ ')' # tupleType | '(' typeExpr (',' typeExpr)+ ')' # tupleType
| generalIdent typeParamList # typeCallType | generalIdent typeParamList # typeCallType
...@@ -164,7 +167,7 @@ typeExpr ...@@ -164,7 +167,7 @@ typeExpr
| '_' # incompleteType | '_' # incompleteType
; ;
typeParamList: '[' generalIdent (',' generalIdent)* ']' ; typeParamList: '[' typeExpr (',' typeExpr)* ']' ;
shapeList shapeList
: '(' ')' : '(' ')'
......
...@@ -184,8 +184,23 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -184,8 +184,23 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#pattern. # Visit a parse tree produced by RelayParser#wildcardPattern.
def visitPattern(self, ctx:RelayParser.PatternContext): def visitWildcardPattern(self, ctx:RelayParser.WildcardPatternContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#varPattern.
def visitVarPattern(self, ctx:RelayParser.VarPatternContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#constructorPattern.
def visitConstructorPattern(self, ctx:RelayParser.ConstructorPatternContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tuplePattern.
def visitTuplePattern(self, ctx:RelayParser.TuplePatternContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -239,6 +254,11 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -239,6 +254,11 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParen.
def visitTypeParen(self, ctx:RelayParser.TypeParenContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeCallType. # Visit a parse tree produced by RelayParser#typeCallType.
def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext): def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
......
...@@ -96,7 +96,7 @@ class Module(RelayNode): ...@@ -96,7 +96,7 @@ class Module(RelayNode):
assert isinstance(val, _ty.Type) assert isinstance(val, _ty.Type)
if isinstance(var, _base.string_types): if isinstance(var, _base.string_types):
var = _ty.GlobalTypeVar(var) var = _ty.GlobalTypeVar(var)
_module.Module_AddDef(self, var, val) _module.Module_AddDef(self, var, val, update)
def __getitem__(self, var): def __getitem__(self, var):
"""Lookup a global definition by name or by variable. """Lookup a global definition by name or by variable.
...@@ -149,6 +149,26 @@ class Module(RelayNode): ...@@ -149,6 +149,26 @@ class Module(RelayNode):
""" """
return _module.Module_GetGlobalVar(self, name) return _module.Module_GetGlobalVar(self, name)
def get_global_vars(self):
"""Collect all global vars defined in this module.
Returns
-------
global_vars: tvm.Array[GlobalVar]
An array of global vars.
"""
return _module.Module_GetGlobalVars(self)
def get_global_type_vars(self):
"""Collect all global type vars defined in this module.
Returns
-------
global_type_vars: tvm.Array[GlobalTypeVar]
An array of global type vars.
"""
return _module.Module_GetGlobalTypeVars(self)
def get_global_type_var(self, name): def get_global_type_var(self, name):
"""Get a global type variable in the function by name. """Get a global type variable in the function by name.
......
...@@ -18,12 +18,299 @@ ...@@ -18,12 +18,299 @@
*/ */
v0.0.4 v0.0.4
def @id[a](%x: a) -> a { // TODO(weberlo): should we add sugar for scalar types (e.g., `int32` => `Tensor[(), int32]`)?
%x
def @id[A](%x: A) -> A {
%x
}
def @compose[A, B, C](%f: fn(B) -> C, %g: fn(A) -> B) {
fn (%x: A) -> C {
%f(%g(%x))
}
}
def @flip[A, B, C](%f: fn(A, B) -> C) -> fn(B, A) -> C {
fn(%b: B, %a: A) -> C {
%f(%a, %b)
}
}
/*
* A LISP-style list ADT. An empty list is represented by `Nil`, and a member
* `x` can be appended to the front of a list `l` via the constructor `Cons(x, l)`.
*/
type List[A] {
Cons(A, List[A]),
Nil,
}
/*
* Get the head of a list. Assume the list has at least one element.
*/
def @hd[A](%xs: List[A]) -> A {
match? (%xs) {
Cons(%x, _) => %x,
}
}
/*
* Get the tail of a list.
*/
def @tl[A](%xs: List[A]) -> List[A] {
match? (%xs) {
Cons(_, %rest) => %rest,
}
}
/*
* Get the `n`th element of a list.
*/
def @nth[A](%xs: List[A], %n: Tensor[(), int32]) -> A {
if (%n == 0) {
@hd(%xs)
} else {
@nth(@tl(%xs), %n - 1)
}
}
/*
* Return the length of a list.
*/
def @length[A](%xs: List[A]) -> Tensor[(), int32] {
match (%xs) {
Cons(_, %rest) => 1 + @length(%rest),
Nil => 0,
}
}
/*
* Update the `n`th element of a list and return the updated list.
*/
def @update[A](%xs: List[A], %n: Tensor[(), int32], %v: A) -> List[A] {
if (%n == 0) {
Cons(%v, @tl(%xs))
} else {
Cons(@hd(%xs), @update(@tl(%xs), %n - 1, %v))
}
}
/*
* Map a function over a list's elements. That is, `map(f, xs)` returns a new
* list where the `i`th member is `f` applied to the `i`th member of `xs`.
*/
def @map[A, B](%f: fn(A) -> B, %xs: List[A]) -> List[B] {
match (%xs) {
Cons(%x, %rest) => Cons(%f(%x), @map(%f, %rest)),
Nil => Nil,
}
}
/*
* A left-way fold over a list.
*
* `foldl(f, z, cons(a1, cons(a2, cons(a3, cons(..., nil)))))`
* evaluates to `f(...f(f(f(z, a1), a2), a3)...)`.
*/
def @foldl[A, B](%f: fn(A, B) -> A, %acc: A, %xs: List[B]) -> A {
match (%xs) {
Cons(%x, %rest) => @foldl(%f, %f(%acc, %x), %rest),
Nil => %acc,
}
} }
def @compose[a, b, c](%f: fn(b) -> c, %g: fn(a) -> b) { /*
fn (%x: a) -> c { * A right-way fold over a list.
%f(%g(%x)) *
} * `foldr(f, z, cons(a1, cons(a2, cons(..., cons(an, nil)))))`
* evaluates to `f(a1, f(a2, f(..., f(an, z)))...)`.
*/
def @foldr[A, B](%f: fn(A, B) -> B, %acc: B, %xs: List[A]) -> B {
match (%xs) {
Cons(%x, %rest) => %f(%x, @foldr(%f, %acc, %rest)),
Nil => %acc,
}
}
/*
* A right-way fold over a nonempty list.
*
* `foldr1(f, cons(a1, cons(a2, cons(..., cons(an, nil)))))`
* evaluates to `f(a1, f(a2, f(..., f(an-1, an)))...)`
*/
def @foldr1[A](%f: fn(A, A) -> A, %xs: List[A]) -> A {
match? (%xs) {
Cons(%x, Nil) => %x,
Cons(%x, %rest) => %f(%x, @foldr1(%f, %rest)),
}
}
/*
* Computes the sum of a list of integer scalars.
*/
def @sum(%xs: List[Tensor[(), int32]]) {
let %add_f = fn(%x: Tensor[(), int32], %y: Tensor[(), int32]) -> Tensor[(), int32] {
%x + %y
};
@foldl(%add_f, 0, %xs)
}
/*
* Concatenates two lists.
*/
def @concat[A](%xs: List[A], %ys: List[A]) -> List[A] {
let %updater = fn(%x: A, %xss: List[A]) -> List[A] {
Cons(%x, %xss)
};
@foldr(%updater, %ys, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldr(Cons, %ys, %xs)
}
/*
* Filters a list, returning a sublist of only the values which satisfy the given predicate.
*/
def @filter[A](%f: fn(A) -> Tensor[(), bool], %xs: List[A]) -> List[A] {
match (%xs) {
Cons(%x, %rest) => {
if (%f(%x)) {
Cons(%x, @filter(%f, %rest))
} else {
@filter(%f, %rest)
}
},
Nil => Nil,
}
}
/*
* Combines two lists into a list of tuples of their elements.
*
* The zipped list will be the length of the shorter list.
*/
def @zip[A, B](%xs: List[A], %ys: List[B]) -> List[(A, B)] {
match (%xs, %ys) {
(Cons(%x, %x_rest), Cons(%y, %y_rest)) => Cons((%x, %y), @zip(%x_rest, %y_rest)),
_ => Nil,
}
}
/*
* Reverses a list.
*/
def @rev[A](%xs: List[A]) -> List[A] {
let %updater = fn(%xss: List[A], %x: A) -> List[A] {
Cons(%x, %xss)
};
@foldl(%updater, Nil, %xs)
// TODO(weberlo): write it like below, once VM constructor compilation is fixed
// @foldl(@flip(Cons), Nil, %xs)
}
/*
* An accumulative map, which is a fold that simulataneously updates an
* accumulator value and a list of results.
*
* This map proceeds through the list from right to left.
*/
def @map_accumr[A, B, C](%f: fn(A, B) -> (A, C), %init: A, %xs: List[B]) -> (A, List[C]) {
let %updater = fn(%x: B, %acc: (A, List[C])) -> (A, List[C]) {
let %f_out = %f(%acc.0, %x);
(%f_out.0, Cons(%f_out.1, %acc.1))
};
@foldr(%updater, (%init, Nil), %xs)
}
/*
* an accumulative map, which is a fold that simulataneously updates an
* accumulator value and a list of results.
*
* This map proceeds through the list from left to right.
*/
def @map_accuml[A, B, C](%f: fn(A, B) -> (A, C), %init: A, %xs: List[B]) -> (A, List[C]) {
let %updater = fn(%acc: (A, List[C]), %x: B) -> (A, List[C]) {
let %f_out = %f(%acc.0, %x);
(%f_out.0, Cons(%f_out.1, %acc.1))
};
@foldl(%updater, (%init, Nil), %xs)
}
/*
* An optional ADT, which can either contain some other type or nothing at all.
*/
type Option[A] {
Some(A),
None,
}
/*
* Builds up a list starting from a seed value.
*
* `f` returns an option containing a new seed and an output value. `f` will
* continue to be called on the new seeds until it returns `None`. All the output
* values will be combined into a list, right to left.
*/
def @unfoldr[A, B](%f: fn(A) -> Option[(A, B)], %seed: A) -> List[B] {
match (%f(%seed)) {
Some(%val) => Cons(%val.1, @unfoldr(%f, %val.0)),
None => Nil,
}
}
/*
* Builds up a list starting from a seed value.
*
* `f` returns an option containing a new seed and an output value. `f` will
* continue to be called on the new seeds until it returns `None`. All the
* output values will be combined into a list, left to right.
*/
def @unfoldl[A, B](%f: fn(A) -> Option[(A, B)], %seed: A) -> List[B] {
@rev(@unfoldr(%f, %seed))
}
/*
* A tree ADT. A tree can contain any type. It has only one
* constructor, rose(x, l), where x is the content of that point of the tree
* and l is a list of more trees of the same type. A leaf is thus rose(x,
* nil()).
*/
type Tree[A] {
Rose(A, List[Tree[A]]),
}
/*
* Maps over a tree. The function is applied to each subtree's contents.
*/
def @tmap[A, B](%f: fn(A) -> B, %t: Tree[A]) -> Tree[B] {
match(%t) {
Rose(%v, %sub_trees) => {
let %list_f = fn(%tt: Tree[A]) -> Tree[B] {
@tmap(%f, %tt)
};
Rose(%f(%v), @map(%list_f, %sub_trees))
},
}
}
/*
* Computes the size of a tree.
*/
def @size[A](%t: Tree[A]) -> Tensor[(), int32] {
match(%t) {
Rose(_, %sub_trees) => {
1 + @sum(@map(@size, %sub_trees))
},
}
}
/*
* Takes a number n and a function f; returns a closure that takes an argument
* and applies f n times to its argument.
*/
def @iterate[A](%f: fn(A) -> A, %n: Tensor[(), int32]) -> (fn(A) -> A) {
if (%n == 0) {
@id
} else {
@compose(%f, @iterate(%f, %n - 1))
}
} }
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......
...@@ -46,14 +46,14 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs, ...@@ -46,14 +46,14 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
for (const auto& kv : n->functions) { for (const auto& kv : n->functions) {
// set global var map // set global var map
CHECK(!n->global_var_map_.count(kv.first->name_hint)) CHECK(n->global_var_map_.count(kv.first->name_hint) == 0)
<< "Duplicate global function name " << kv.first->name_hint; << "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first); n->global_var_map_.Set(kv.first->name_hint, kv.first);
} }
for (const auto& kv : n->type_definitions) { for (const auto& kv : n->type_definitions) {
// set global typevar map // set global typevar map
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) CHECK(n->global_type_var_map_.count(kv.first->var->name_hint) == 0)
<< "Duplicate global type definition name " << kv.first->var->name_hint; << "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second); n->RegisterConstructors(kv.first, kv.second);
...@@ -73,20 +73,12 @@ GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { ...@@ -73,20 +73,12 @@ GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
return (*it).second; return (*it).second;
} }
void ModuleNode::AddUnchecked(const GlobalVar& var, tvm::Array<GlobalVar> ModuleNode::GetGlobalVars() const {
const Function& func) { std::vector<GlobalVar> global_vars;
auto mod = GetRef<Module>(this); for (const auto& pair : global_var_map_) {
this->functions.Set(var, func); global_vars.push_back(pair.second);
auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
CHECK_EQ((*it).second, var);
} else {
CHECK(!global_var_map_.count(var->name_hint))
<< "Duplicate global function name " << var->name_hint;
} }
return tvm::Array<GlobalVar>(global_vars);
global_var_map_.Set(var->name_hint, var);
} }
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
...@@ -97,6 +89,14 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { ...@@ -97,6 +89,14 @@ GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
return (*it).second; return (*it).second;
} }
tvm::Array<GlobalTypeVar> ModuleNode::GetGlobalTypeVars() const {
std::vector<GlobalTypeVar> global_type_vars;
for (const auto& pair : global_type_var_map_) {
global_type_vars.push_back(pair.second);
}
return tvm::Array<GlobalTypeVar>(global_type_vars);
}
template<typename T> template<typename T>
tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) { tvm::Array<T> concat(const tvm::Array<T>& l, const tvm::Array<T>& r) {
tvm::Array<T> ret(l); tvm::Array<T> ret(l);
...@@ -151,6 +151,22 @@ void ModuleNode::Add(const GlobalVar& var, ...@@ -151,6 +151,22 @@ void ModuleNode::Add(const GlobalVar& var,
AddUnchecked(var, checked_func); AddUnchecked(var, checked_func);
} }
void ModuleNode::AddUnchecked(const GlobalVar& var,
const Function& func) {
auto mod = GetRef<Module>(this);
this->functions.Set(var, func);
auto it = global_var_map_.find(var->name_hint);
if (it != global_var_map_.end()) {
CHECK_EQ((*it).second, var);
} else {
CHECK(global_var_map_.count(var->name_hint) == 0)
<< "Duplicate global function name " << var->name_hint;
}
global_var_map_.Set(var->name_hint, var);
}
void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
// We hash the global type var name to use as a globally unique prefix for tags. // We hash the global type var name to use as a globally unique prefix for tags.
// The hash will be used as the most significant byte of the tag, with the index of // The hash will be used as the most significant byte of the tag, with the index of
...@@ -163,25 +179,33 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& ...@@ -163,25 +179,33 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
} }
} }
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type, bool update) {
this->type_definitions.Set(var, type); AddDefUnchecked(var, type, update);
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
RegisterConstructors(var, type);
// need to kind check at the end because the check can look up // need to kind check at the end because the check can look up
// a definition potentially // a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData) CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type; << "Invalid or malformed typedata given to module: " << type;
} }
void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, bool update) {
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
CHECK(global_type_var_map_.count(var->var->name_hint) == 0)
<< "Duplicate global type definition name " << var->var->name_hint;
}
global_type_var_map_.Set(var->var->name_hint, var);
RegisterConstructors(var, type);
}
void ModuleNode::Update(const GlobalVar& var, const Function& func) { void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true); this->Add(var, func, true);
} }
void ModuleNode::UpdateDef(const GlobalTypeVar& var, const TypeData& type) {
this->AddDef(var, type, true);
}
void ModuleNode::Remove(const GlobalVar& var) { void ModuleNode::Remove(const GlobalVar& var) {
auto functions_node = this->functions.CopyOnWrite(); auto functions_node = this->functions.CopyOnWrite();
functions_node->data.erase(var.node_); functions_node->data.erase(var.node_);
...@@ -226,9 +250,20 @@ Constructor ModuleNode::LookupTag(const int32_t tag) { ...@@ -226,9 +250,20 @@ Constructor ModuleNode::LookupTag(const int32_t tag) {
} }
void ModuleNode::Update(const Module& mod) { void ModuleNode::Update(const Module& mod) {
// add functions and type defs. we add them unchecked first, so all definitions
// can reference each other, independent of the order in which they were defined.
for (auto pair : mod->functions) {
this->AddUnchecked(pair.first, pair.second);
}
for (auto pair : mod->type_definitions) {
this->AddDefUnchecked(pair.first, pair.second);
}
for (auto pair : mod->functions) { for (auto pair : mod->functions) {
this->Update(pair.first, pair.second); this->Update(pair.first, pair.second);
} }
for (auto pair : mod->type_definitions) {
this->UpdateDef(pair.first, pair.second);
}
} }
Module ModuleNode::FromExpr( Module ModuleNode::FromExpr(
...@@ -257,14 +292,7 @@ void ModuleNode::Import(const std::string& path) { ...@@ -257,14 +292,7 @@ void ModuleNode::Import(const std::string& path) {
std::istreambuf_iterator<char>(src_file), std::istreambuf_iterator<char>(src_file),
std::istreambuf_iterator<char>() }; std::istreambuf_iterator<char>() };
auto mod_to_import = FromText(file_contents, path); auto mod_to_import = FromText(file_contents, path);
Update(mod_to_import);
for (auto func : mod_to_import->functions) {
this->Add(func.first, func.second, false);
}
for (auto type : mod_to_import->type_definitions) {
this->AddDef(type.first, type.second);
}
} }
} }
...@@ -315,6 +343,12 @@ TVM_REGISTER_API("relay._module.Module_AddDef") ...@@ -315,6 +343,12 @@ TVM_REGISTER_API("relay._module.Module_AddDef")
TVM_REGISTER_API("relay._module.Module_GetGlobalVar") TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body_method<Module>(&ModuleNode::GetGlobalVar); .set_body_method<Module>(&ModuleNode::GetGlobalVar);
TVM_REGISTER_API("relay._module.Module_GetGlobalVars")
.set_body_method<Module>(&ModuleNode::GetGlobalVars);
TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVars")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVars);
TVM_REGISTER_API("relay._module.Module_ContainGlobalVar") TVM_REGISTER_API("relay._module.Module_ContainGlobalVar")
.set_body_method<Module>(&ModuleNode::ContainGlobalVar); .set_body_method<Module>(&ModuleNode::ContainGlobalVar);
......
...@@ -570,7 +570,13 @@ class PrettyPrinter : ...@@ -570,7 +570,13 @@ class PrettyPrinter :
} else { } else {
doc << Print(op->op); doc << Print(op->op);
} }
return doc << "(" << PrintSep(args) << ")";
if (cons_node && cons_node->inputs.size() == 0) {
// don't print as a call if it's a 0-arity cons
return doc;
} else {
return doc << "(" << PrintSep(args) << ")";
}
} }
Doc VisitExpr_(const RefCreateNode* op) final { Doc VisitExpr_(const RefCreateNode* op) final {
...@@ -641,6 +647,17 @@ class PrettyPrinter : ...@@ -641,6 +647,17 @@ class PrettyPrinter :
return doc; return doc;
} }
Doc VisitPattern_(const PatternTupleNode* pt) final {
Doc doc;
doc << "(";
std::vector<Doc> pats;
for (const auto& pat : pt->patterns) {
pats.push_back(Print(pat));
}
doc << PrintSep(pats) << ")";
return doc;
}
Doc VisitPattern_(const PatternWildcardNode* pw) final { Doc VisitPattern_(const PatternWildcardNode* pw) final {
return Doc("_"); return Doc("_");
} }
......
...@@ -800,12 +800,13 @@ def test_adt_cons_expr(): ...@@ -800,12 +800,13 @@ def test_adt_cons_expr():
%s %s
def @make_singleton(%%x: int32) -> List[int32] { def @make_singleton(%%x: int32) -> List[int32] {
Cons(%%x, Nil()) Cons(%%x, Nil)
} }
""" % LIST_DEFN, """ % LIST_DEFN,
mod mod
) )
@raises_parse_error @raises_parse_error
def test_duplicate_adt_defn(): def test_duplicate_adt_defn():
parse_text( parse_text(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment