Commit ca0292d8 by Logan Weber Committed by Jared Roesch

[Relay] Add ADTs to text format (#3863)

* Getting closer to having ADT defs

* ADT defs working probly

* Match parsing basipally done

* came to earth in a silver chrome UFO

* match finished?

* All tests but newest are passing

* ADT constructors work

now cleanup?

* Cleanup round 1

* Cleanup round 2

* Cleanup round 3

* Cleanup round 4

* Cleanup round 6

* Cleanup round 7

* Lil grammar fix

* Remove ANTLR Java files

* Lint roller

* Lint roller

* Address feedback

* Test completeness in match test

* Remove unused imports

* Lint roller

* Switch to Rust-style ADT syntax

* Lil fix

* Add dummy `extern type` handler

* Add type arg to test

* Update prelude semantic version

* Repair test

* Fix graph var handling in match

* Revert 's/graph_equal/is_unifiable' change
parent a103c4ee
...@@ -165,6 +165,13 @@ class ModuleNode : public RelayNode { ...@@ -165,6 +165,13 @@ class ModuleNode : public RelayNode {
TVM_DLL TypeData LookupDef(const std::string& var) const; TVM_DLL TypeData LookupDef(const std::string& var) const;
/*! /*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag. * \brief Look up a constructor by its tag.
* \param tag The tag for the constructor. * \param tag The tag for the constructor.
* \return The constructor object. * \return The constructor object.
......
...@@ -17,11 +17,15 @@ ...@@ -17,11 +17,15 @@
* under the License. * under the License.
*/ */
// list = *, seq = ? /*
* NOTE: The `USE_ANTLR` option in `config.cmake` must be enabled in order for
* changes in this file to be reflected by the parser.
* NOTE: All upper-case rules are *lexer* rules and all camel-case rules are *parser* rules.
*/
grammar Relay; grammar Relay;
SEMVER: 'v0.0.3' ; SEMVER: 'v0.0.4' ;
// Lexing // Lexing
// comments // comments
...@@ -49,13 +53,8 @@ BOOL_LIT ...@@ -49,13 +53,8 @@ BOOL_LIT
| 'False' | 'False'
; ;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*; CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
DATATYPE : 'int64';
// non-negative floats // non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4 fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
...@@ -74,7 +73,11 @@ METADATA: 'METADATA:' .*; ...@@ -74,7 +73,11 @@ METADATA: 'METADATA:' .*;
// A Relay program is a list of global definitions or an expression. // A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) METADATA? EOF ; prog: SEMVER (defn* | expr) METADATA? EOF ;
// option: 'set' ident BOOL_LIT ; // Covers both operator and type idents
generalIdent: CNAME ('.' CNAME)*;
globalVar: '@' CNAME ;
localVar: '%' ('_' | CNAME) ;
graphVar: '%' NAT ;
exprList: (expr (',' expr)*)?; exprList: (expr (',' expr)*)?;
callList callList
...@@ -85,7 +88,6 @@ callList ...@@ -85,7 +88,6 @@ callList
expr expr
// operators // operators
: '(' expr ')' # paren : '(' expr ')' # paren
| '{' expr '}' # paren
// function application // function application
| expr '(' callList ')' # call | expr '(' callList ')' # call
| '-' expr # neg | '-' expr # neg
...@@ -99,53 +101,74 @@ expr ...@@ -99,53 +101,74 @@ expr
| '(' ')' # tuple | '(' ')' # tuple
| '(' expr ',' ')' # tuple | '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple | '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor | '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse | 'if' '(' expr ')' body 'else' body # ifElse
| matchType '(' expr ')' '{' matchClauseList? '}' # match
| expr '.' NAT # projection
// sequencing // sequencing
| 'let' var '=' expr ';' expr # let | 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr // sugar for let %_ = expr; expr
| expr ';;' expr # let | expr ';;' expr # let
| GRAPH_VAR '=' expr ';' expr # graph | graphVar '=' expr ';' expr # graph
| ident # identExpr | ident # identExpr
| scalar # scalarExpr | scalar # scalarExpr
| meta # metaExpr | meta # metaExpr
| QUOTED_STRING # stringExpr | QUOTED_STRING # stringExpr
; ;
func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ; func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ;
defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ; defn
: 'def' globalVar typeParamList? '(' argList ')' ('->' typeExpr)? body # funcDefn
| 'extern' 'type' generalIdent typeParamList? # externAdtDefn
| 'type' generalIdent typeParamList? '{' adtConsDefnList? '}' # adtDefn
;
constructorName: CNAME ;
adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ;
adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ;
matchClauseList: matchClause (',' matchClause)* ','? ;
matchClause: constructorName patternList? '=>' ('{' expr '}' | expr) ;
// complete or incomplete match, respectively
matchType : 'match' | 'match?' ;
patternList: '(' pattern (',' pattern)* ')';
pattern
: '_'
| localVar (':' typeExpr)?
;
adtCons: constructorName adtConsParamList? ;
adtConsParamList: '(' adtConsParam (',' adtConsParam)* ')' ;
adtConsParam: localVar | constructorName ;
argList argList
: varList # argNoAttr : varList # argNoAttr
| (var ',')* attrSeq # argWithAttr | (var ',')* attrSeq # argWithAttr
; ;
varList: (var (',' var)*)?; varList: (var (',' var)*)? ;
var: LOCAL_VAR (':' type_)?; var: localVar (':' typeExpr)? ;
attrSeq: attr (',' attr)*; attrSeq: attr (',' attr)* ;
attr: CNAME '=' expr ; attr: CNAME '=' expr ;
typeParamList typeExpr
: '[' ']'
| '[' ident (',' ident)* ']'
;
type_
: '(' ')' # tupleType : '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType | '(' typeExpr ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType | '(' typeExpr (',' typeExpr)+ ')' # tupleType
| typeIdent # typeIdentType | generalIdent typeParamList # typeCallType
| 'Tensor' '[' shapeList ',' type_ ']' # tensorType | generalIdent # typeIdentType
| 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | 'Tensor' '[' shapeList ',' typeExpr ']' # tensorType
| 'fn' typeParamList? '(' (typeExpr (',' typeExpr)*)? ')' '->' typeExpr # funcType
| '_' # incompleteType | '_' # incompleteType
| NAT # intType
; ;
typeParamList: '[' generalIdent (',' generalIdent)* ']' ;
shapeList shapeList
: '(' shape (',' shape)+ ')' : '(' ')'
| '(' ')' | '(' shape (',' shape)+ ')'
| shape | shape
; ;
...@@ -157,12 +180,6 @@ shape ...@@ -157,12 +180,6 @@ shape
| NAT # intShape | NAT # intShape
; ;
typeIdent : CNAME;
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
// bool
body: '{' expr '}' ; body: '{' expr '}' ;
scalar scalar
...@@ -172,8 +189,8 @@ scalar ...@@ -172,8 +189,8 @@ scalar
; ;
ident ident
: opIdent : generalIdent
| GLOBAL_VAR | globalVar
| LOCAL_VAR | localVar
| GRAPH_VAR | graphVar
; ;
This source diff could not be displayed because it is too large. You can view the blob instead.
# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1 # Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
from antlr4 import * from antlr4 import *
if __name__ is not None and "." in __name__: if __name__ is not None and "." in __name__:
from .RelayParser import RelayParser from .RelayParser import RelayParser
...@@ -9,13 +9,28 @@ else: ...@@ -9,13 +9,28 @@ else:
class RelayVisitor(ParseTreeVisitor): class RelayVisitor(ParseTreeVisitor):
# Visit a parse tree produced by RelayParser#opIdent. # Visit a parse tree produced by RelayParser#prog.
def visitOpIdent(self, ctx:RelayParser.OpIdentContext): def visitProg(self, ctx:RelayParser.ProgContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#prog. # Visit a parse tree produced by RelayParser#generalIdent.
def visitProg(self, ctx:RelayParser.ProgContext): def visitGeneralIdent(self, ctx:RelayParser.GeneralIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#globalVar.
def visitGlobalVar(self, ctx:RelayParser.GlobalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#localVar.
def visitLocalVar(self, ctx:RelayParser.LocalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graphVar.
def visitGraphVar(self, ctx:RelayParser.GraphVarContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -44,6 +59,11 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -44,6 +59,11 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#match.
def visitMatch(self, ctx:RelayParser.MatchContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor. # Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext): def visitTensor(self, ctx:RelayParser.TensorContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -114,8 +134,73 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -114,8 +134,73 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#defn. # Visit a parse tree produced by RelayParser#funcDefn.
def visitDefn(self, ctx:RelayParser.DefnContext): def visitFuncDefn(self, ctx:RelayParser.FuncDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#externAdtDefn.
def visitExternAdtDefn(self, ctx:RelayParser.ExternAdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtDefn.
def visitAdtDefn(self, ctx:RelayParser.AdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#constructorName.
def visitConstructorName(self, ctx:RelayParser.ConstructorNameContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefnList.
def visitAdtConsDefnList(self, ctx:RelayParser.AdtConsDefnListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefn.
def visitAdtConsDefn(self, ctx:RelayParser.AdtConsDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClauseList.
def visitMatchClauseList(self, ctx:RelayParser.MatchClauseListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClause.
def visitMatchClause(self, ctx:RelayParser.MatchClauseContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchType.
def visitMatchType(self, ctx:RelayParser.MatchTypeContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#patternList.
def visitPatternList(self, ctx:RelayParser.PatternListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#pattern.
def visitPattern(self, ctx:RelayParser.PatternContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtCons.
def visitAdtCons(self, ctx:RelayParser.AdtConsContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParamList.
def visitAdtConsParamList(self, ctx:RelayParser.AdtConsParamListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParam.
def visitAdtConsParam(self, ctx:RelayParser.AdtConsParamContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -149,13 +234,13 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -149,13 +234,13 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParamList. # Visit a parse tree produced by RelayParser#tupleType.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext): def visitTupleType(self, ctx:RelayParser.TupleTypeContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tupleType. # Visit a parse tree produced by RelayParser#typeCallType.
def visitTupleType(self, ctx:RelayParser.TupleTypeContext): def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -179,8 +264,8 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -179,8 +264,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#intType. # Visit a parse tree produced by RelayParser#typeParamList.
def visitIntType(self, ctx:RelayParser.IntTypeContext): def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -209,11 +294,6 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -209,11 +294,6 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeIdent.
def visitTypeIdent(self, ctx:RelayParser.TypeIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#body. # Visit a parse tree produced by RelayParser#body.
def visitBody(self, ctx:RelayParser.BodyContext): def visitBody(self, ctx:RelayParser.BodyContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ */
v0.0.3 v0.0.4
def @id[a](%x: a) -> a { def @id[a](%x: a) -> a {
%x %x
......
...@@ -70,7 +70,10 @@ class AlphaEqualHandler: ...@@ -70,7 +70,10 @@ class AlphaEqualHandler:
} }
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) { for (const auto& p : lhsm->type_definitions) {
if (!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) return false; if (!rhsm->HasDef(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false;
}
} }
return true; return true;
} }
...@@ -288,7 +291,7 @@ class AlphaEqualHandler: ...@@ -288,7 +291,7 @@ class AlphaEqualHandler:
} }
bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final { bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
return GetRef<Type>(lhs) == other; return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} }
bool VisitType_(const TypeCallNode* lhs, const Type& other) final { bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
...@@ -307,6 +310,26 @@ class AlphaEqualHandler: ...@@ -307,6 +310,26 @@ class AlphaEqualHandler:
return true; return true;
} }
bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
const TypeDataNode* rhs = other.as<TypeDataNode>();
if (rhs == nullptr
|| lhs->type_vars.size() != rhs->type_vars.size()
|| !TypeEqual(lhs->header, rhs->header)) {
return false;
}
for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->constructors.size(); ++i) {
if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
return false;
}
}
return true;
}
// Expr equal checking. // Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs, bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) { const runtime::NDArray& rhs) {
...@@ -485,7 +508,10 @@ class AlphaEqualHandler: ...@@ -485,7 +508,10 @@ class AlphaEqualHandler:
} }
bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final { bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
return GetRef<Expr>(lhs) == other; if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
return lhs->name_hint == rhs->name_hint;
}
return false;
} }
bool ClauseEqual(const Clause& lhs, const Clause& rhs) { bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
...@@ -582,7 +608,7 @@ TVM_REGISTER_API("relay._make._alpha_equal") ...@@ -582,7 +608,7 @@ TVM_REGISTER_API("relay._make._alpha_equal")
TVM_REGISTER_API("relay._make._assert_alpha_equal") TVM_REGISTER_API("relay._make._assert_alpha_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) { .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not alpha equal"; CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
}); });
TVM_REGISTER_API("relay._make._graph_equal") TVM_REGISTER_API("relay._make._graph_equal")
...@@ -593,7 +619,7 @@ TVM_REGISTER_API("relay._make._graph_equal") ...@@ -593,7 +619,7 @@ TVM_REGISTER_API("relay._make._graph_equal")
TVM_REGISTER_API("relay._make._assert_graph_equal") TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) { .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not graph equal"; CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
}); });
} // namespace relay } // namespace relay
......
...@@ -206,6 +206,11 @@ TypeData ModuleNode::LookupDef(const std::string& name) const { ...@@ -206,6 +206,11 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id); return this->LookupDef(id);
} }
bool ModuleNode::HasDef(const std::string& name) const {
auto it = global_type_var_map_.find(name);
return it != global_type_var_map_.end();
}
Constructor ModuleNode::LookupTag(const int32_t tag) { Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag); auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end()) CHECK(it != constructor_tag_map_.end())
......
...@@ -44,6 +44,8 @@ ...@@ -44,6 +44,8 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
static const char* kSemVer = "v0.0.4";
Doc Brace(const Doc& d, Doc Brace(const Doc& d,
const std::string& open = "{", const std::string& open = "{",
const std::string& close = "}", const std::string& close = "}",
...@@ -239,6 +241,8 @@ class PrettyPrinter : ...@@ -239,6 +241,8 @@ class PrettyPrinter :
return PrintExpr(Downcast<Expr>(node), meta, try_inline); return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) { } else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta); return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as_derived<ModuleNode>()) { } else if (node.as_derived<ModuleNode>()) {
return PrintMod(Downcast<Module>(node)); return PrintMod(Downcast<Module>(node));
} else { } else {
...@@ -313,7 +317,7 @@ class PrettyPrinter : ...@@ -313,7 +317,7 @@ class PrettyPrinter :
if (name.length() == 0 || !std::isalpha(name[0])) { if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name; name = "t" + name;
} }
Doc val = GetUniqueName("%" + name); Doc val = GetUniqueName(name);
memo_type_[var] = val; memo_type_[var] = val;
if (var->kind != kType) { if (var->kind != kType) {
val << ": " << Print(var->kind); val << ": " << Print(var->kind);
...@@ -347,13 +351,17 @@ class PrettyPrinter : ...@@ -347,13 +351,17 @@ class PrettyPrinter :
} }
bool IsUnique(const Expr& expr) { bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head && auto it = dg_.expr_node.find(expr);
dg_.expr_node.at(expr)->parents.head->next); if (it == dg_.expr_node.end()) {
return true;
} else {
return !(it->second->parents.head && it->second->parents.head->next);
}
} }
bool AlwaysInline(const Expr& expr) { bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
expr.as<OpNode>() || expr.as<VarNode>(); expr.as<VarNode>() || expr.as<ConstructorNode>();
} }
//------------------------------------ //------------------------------------
...@@ -380,9 +388,9 @@ class PrettyPrinter : ...@@ -380,9 +388,9 @@ class PrettyPrinter :
} else if (!inline_expr && expr.as<LetNode>()) { } else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets // wrap GNFed let in brackets
Doc body; Doc body;
printed_expr << "{"; printed_expr << "(";
printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine(); printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << "}"; printed_expr << ")";
} else { } else {
printed_expr = VisitExpr(expr); printed_expr = VisitExpr(expr);
} }
...@@ -483,13 +491,13 @@ class PrettyPrinter : ...@@ -483,13 +491,13 @@ class PrettyPrinter :
Doc doc; Doc doc;
doc << prefix; doc << prefix;
if (fn->type_params.size() > 0) { if (fn->type_params.size() > 0) {
doc << "<"; doc << "[";
std::vector<Doc> type_params; std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) { for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv)); type_params.push_back(Doc(tv->var->name_hint));
} }
doc << PrintSep(type_params); doc << PrintSep(type_params);
doc << ">"; doc << "]";
} }
doc << "("; doc << "(";
std::vector<Doc> params; std::vector<Doc> params;
...@@ -510,6 +518,15 @@ class PrettyPrinter : ...@@ -510,6 +518,15 @@ class PrettyPrinter :
Doc PrintMod(const Module& mod) { Doc PrintMod(const Module& mod) {
Doc doc; Doc doc;
int counter = 0; int counter = 0;
// type definitions
for (const auto& kv : mod->type_definitions) {
if (counter++ != 0) {
doc << PrintNewLine();
}
doc << Print(kv.second);
doc << PrintNewLine();
}
// functions
for (const auto& kv : mod->functions) { for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second); dg_ = DependencyGraph::Create(&arena_, kv.second);
...@@ -547,7 +564,12 @@ class PrettyPrinter : ...@@ -547,7 +564,12 @@ class PrettyPrinter :
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d); args.push_back(d);
} }
const auto* cons_node = op->op.as<ConstructorNode>();
if (cons_node) {
doc << cons_node->name_hint;
} else {
doc << Print(op->op); doc << Print(op->op);
}
return doc << "(" << PrintSep(args) << ")"; return doc << "(" << PrintSep(args) << ")";
} }
...@@ -570,27 +592,57 @@ class PrettyPrinter : ...@@ -570,27 +592,57 @@ class PrettyPrinter :
// TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
Doc doc; Doc doc;
Doc body; Doc body;
doc << "match " << Print(op->data) << " "; doc << "match";
doc << "{"; if (!op->complete) {
std::vector<Doc> clauses; doc << "?";
}
doc << " (" << Print(op->data) << ") {";
std::vector<Doc> clause_docs;
for (const auto& clause : op->clauses) { for (const auto& clause : op->clauses) {
Doc clause_doc; Doc clause_doc;
clauses.push_back(clause_doc << Print(clause->lhs) << " -> " clause_doc << PrintPattern(clause->lhs, false) << " => ";
<< Print(clause->rhs)); Doc rhs_doc = PrintScope(clause->rhs);
if (clause->rhs.as<LetNode>()) {
// only add braces if there are multiple lines on the rhs
rhs_doc = Brace(rhs_doc);
} }
doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine(); clause_doc << rhs_doc << ",";
doc << "}"; clause_docs.push_back(clause_doc);
}
doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
<< PrintNewLine() << "}";
return doc; return doc;
} }
Doc PrintPattern(const Pattern& pattern, bool meta) {
auto it = memo_pattern_.find(pattern);
if (it != memo_pattern_.end()) return it->second;
Doc printed_pattern;
if (meta) {
printed_pattern = meta_.GetMetaNode(GetRef<NodeRef>(pattern.get()));
} else {
printed_pattern = VisitPattern(pattern);
}
memo_pattern_[pattern] = printed_pattern;
return printed_pattern;
}
Doc VisitPattern_(const PatternConstructorNode* p) final { Doc VisitPattern_(const PatternConstructorNode* p) final {
Doc doc; Doc doc;
doc << p->constructor->name_hint << "("; doc << p->constructor->name_hint;
if (!p->patterns.empty()) {
doc << "(";
std::vector<Doc> pats; std::vector<Doc> pats;
for (const auto& pat : p->patterns) { for (const auto& pat : p->patterns) {
pats.push_back(Print(pat)); pats.push_back(Print(pat));
} }
return doc << PrintSep(pats) << ")"; doc << PrintSep(pats) << ")";
}
return doc;
}
Doc VisitPattern_(const PatternWildcardNode* pw) final {
return Doc("_");
} }
Doc VisitPattern_(const PatternVarNode* pv) final { Doc VisitPattern_(const PatternVarNode* pv) final {
...@@ -598,7 +650,17 @@ class PrettyPrinter : ...@@ -598,7 +650,17 @@ class PrettyPrinter :
} }
Doc VisitExpr_(const ConstructorNode* n) final { Doc VisitExpr_(const ConstructorNode* n) final {
return Doc(n->name_hint); Doc doc;
doc << n->name_hint;
if (n->inputs.size() != 0) {
doc << "(";
std::vector<Doc> inputs;
for (Type input : n->inputs) {
inputs.push_back(Print(input));
}
doc << PrintSep(inputs) << ")";
}
return doc;
} }
//------------------------------------ //------------------------------------
...@@ -623,7 +685,7 @@ class PrettyPrinter : ...@@ -623,7 +685,7 @@ class PrettyPrinter :
} }
Doc VisitType_(const TypeVarNode* node) final { Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node)); return Doc(node->var->name_hint);
} }
Doc VisitType_(const GlobalTypeVarNode* node) final { Doc VisitType_(const GlobalTypeVarNode* node) final {
...@@ -675,13 +737,13 @@ class PrettyPrinter : ...@@ -675,13 +737,13 @@ class PrettyPrinter :
Doc doc; Doc doc;
doc << "fn "; doc << "fn ";
if (node->type_params.size() != 0) { if (node->type_params.size() != 0) {
doc << "<"; doc << "[";
std::vector<Doc> type_params; std::vector<Doc> type_params;
for (Type type_param : node->type_params) { for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param)); type_params.push_back(Print(type_param));
} }
doc << PrintSep(type_params); doc << PrintSep(type_params);
doc << ">"; doc << "]";
} }
std::vector<Doc> arg_types; std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) { for (Type arg_type : node->arg_types) {
...@@ -695,6 +757,37 @@ class PrettyPrinter : ...@@ -695,6 +757,37 @@ class PrettyPrinter :
return doc << "ref(" << Print(node->value) << ")"; return doc << "ref(" << Print(node->value) << ")";
} }
Doc VisitType_(const TypeDataNode* node) final {
Doc doc;
doc << "type " << Print(node->header);
// type vars
if (node->type_vars.size() != 0) {
doc << "[";
std::vector<Doc> type_vars;
for (Type type_var : node->type_vars) {
type_vars.push_back(Print(type_var));
}
doc << PrintSep(type_vars) << "]";
}
doc << " ";
std::vector<Doc> constructor_docs;
for (Constructor constructor : node->constructors) {
constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
}
Doc separator;
separator << "," << PrintNewLine();
Doc adt_body;
adt_body << PrintSep(constructor_docs, separator);
// add trailing comma if there are any constructors
if (!constructor_docs.empty()) {
adt_body << ",";
}
doc << Brace(adt_body);
return doc;
}
//------------------------------------ //------------------------------------
// Overload of Attr printing functions // Overload of Attr printing functions
//------------------------------------ //------------------------------------
...@@ -758,6 +851,8 @@ class PrettyPrinter : ...@@ -758,6 +851,8 @@ class PrettyPrinter :
std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_; std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
/*! \brief Map from Type to Doc */ /*! \brief Map from Type to Doc */
std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_; std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
/*! \brief Map from Type to Doc */
std::unordered_map<Pattern, Doc, NodeHash, NodeEqual> memo_pattern_;
/*! \brief name allocation map */ /*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_; std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief meta data context */ /*! \brief meta data context */
...@@ -861,7 +956,7 @@ std::string PrettyPrint_(const NodeRef& node, ...@@ -861,7 +956,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data, bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) { runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc; Doc doc;
doc << "v0.0.3" << PrintNewLine() doc << kSemVer << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node); << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str(); return doc.str();
} }
......
...@@ -774,7 +774,6 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { ...@@ -774,7 +774,6 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
bool update_missing_type_annotation_{true}; bool update_missing_type_annotation_{true};
}; };
Expr TypeInferencer::Infer(Expr expr) { Expr TypeInferencer::Infer(Expr expr) {
// Step 1: Populate the constraints. // Step 1: Populate the constraints.
GetType(expr); GetType(expr);
......
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
# under the License. # under the License.
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.analysis import alpha_equal, assert_alpha_equal from tvm.relay.analysis import graph_equal, assert_graph_equal
from nose.tools import nottest, raises from nose.tools import nottest, raises
from numpy import isclose from numpy import isclose
from typing import Union from typing import Union
from functools import wraps from functools import wraps
raises_parse_error = raises(tvm._ffi.base.TVMError) raises_parse_error = raises(tvm._ffi.base.TVMError)
SEMVER = "v0.0.3" SEMVER = "v0.0.4"
BINARY_OPS = { BINARY_OPS = {
"*": relay.multiply, "*": relay.multiply,
...@@ -60,20 +60,29 @@ TYPES = { ...@@ -60,20 +60,29 @@ TYPES = {
"float16x4", "float16x4",
} }
LIST_DEFN = """
type List[A] {
Cons(A, List[A]),
Nil,
}
"""
def roundtrip(expr): def roundtrip(expr):
x = relay.fromtext(str(expr)) x = relay.fromtext(str(expr))
assert_alpha_equal(x, expr) assert_graph_equal(x, expr)
def parse_text(code): def parse_text(code):
x = relay.fromtext(SEMVER + "\n" + code) expr = relay.fromtext(SEMVER + "\n" + code)
roundtrip(x) roundtrip(expr)
return x return expr
def parses_as(code, expr): def parses_as(code, expr):
# type: (str, relay.Expr) -> bool # type: (str, relay.Expr) -> bool
return alpha_equal(parse_text(code), expr) parsed = parse_text(code)
result = graph_equal(parsed, expr)
return result
def get_scalar(x): def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool]) # type: (relay.Constant) -> (Union[float, int, bool])
...@@ -168,13 +177,13 @@ def test_bin_op(): ...@@ -168,13 +177,13 @@ def test_bin_op():
def test_parens(): def test_parens():
assert alpha_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1")) assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not alpha_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)")) assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
def test_op_assoc(): def test_op_assoc():
assert alpha_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1")) assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert alpha_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))")) assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
@nottest @nottest
...@@ -239,7 +248,7 @@ def test_seq(): ...@@ -239,7 +248,7 @@ def test_seq():
) )
assert parses_as( assert parses_as(
"let %_ = { 1 }; ()", "let %_ = 1; ()",
relay.Let( relay.Let(
X, X,
relay.const(1), relay.const(1),
...@@ -249,13 +258,13 @@ def test_seq(): ...@@ -249,13 +258,13 @@ def test_seq():
def test_graph(): def test_graph():
code = "%0 = (); %1 = 1; (%0, %0, %1)"
assert parses_as( assert parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)", code,
relay.Tuple([UNIT, UNIT, relay.const(1)]) relay.Tuple([UNIT, UNIT, relay.const(1)])
) )
assert not parses_as( assert not parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)", code,
relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)]) relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
) )
...@@ -632,6 +641,236 @@ def test_tuple_type(): ...@@ -632,6 +641,236 @@ def test_tuple_type():
) )
) )
def test_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(
glob_typ_var,
[],
[relay.Constructor("Nil", [], glob_typ_var)])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { Nil }
""",
mod
)
def test_empty_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(glob_typ_var, [], [])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { }
""",
mod
)
def test_multiple_cons_defn():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
prog = relay.TypeData(
list_var,
[typ_var],
[
relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
relay.Constructor("Nil", [], list_var),
])
mod[list_var] = prog
assert parses_as(LIST_DEFN, mod)
def test_multiple_type_param_defn():
glob_typ_var = relay.GlobalTypeVar("Either")
typ_var_a = relay.TypeVar("A")
typ_var_b = relay.TypeVar("B")
prog = relay.TypeData(
glob_typ_var,
[typ_var_a, typ_var_b],
[
relay.Constructor("Left", [typ_var_a], glob_typ_var),
relay.Constructor("Right", [typ_var_b], glob_typ_var),
])
mod = relay.Module()
mod[glob_typ_var] = prog
assert parses_as(
"""
type Either[A, B] {
Left(A),
Right(B),
}
""",
mod
)
def test_match():
# pair each match keyword with whether it specifies a complete match or not
match_keywords = [("match", True), ("match?", False)]
for (match_keyword, is_complete) in match_keywords:
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
length_var = relay.GlobalVar("length")
typ_var = relay.TypeVar("A")
input_type = list_var(typ_var)
input_var = relay.Var("xs", input_type)
rest_var = relay.Var("rest")
cons_case = relay.Let(
_,
UNIT,
relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
body = relay.Match(input_var,
[relay.Clause(
relay.PatternConstructor(
cons_constructor,
[relay.PatternWildcard(), relay.PatternVar(rest_var)]),
cons_case),
relay.Clause(
relay.PatternConstructor(nil_constructor, []),
relay.const(0))],
complete=is_complete
)
length_func = relay.Function(
[input_var],
body,
int32,
[typ_var]
)
mod[length_var] = length_func
assert parses_as(
"""
%s
def @length[A](%%xs: List[A]) -> int32 {
%s (%%xs) {
Cons(_, %%rest) => {
();;
1 + @length(%%rest)
},
Nil => 0,
}
}
""" % (LIST_DEFN, match_keyword),
mod
)
def test_adt_cons_expr():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
make_singleton_var = relay.GlobalVar("make_singleton")
input_var = relay.Var("x", int32)
make_singleton_func = relay.Function(
[input_var],
cons_constructor(input_var, nil_constructor()),
list_var(int32)
)
mod[make_singleton_var] = make_singleton_func
assert parses_as(
"""
%s
def @make_singleton(%%x: int32) -> List[int32] {
Cons(%%x, Nil())
}
""" % LIST_DEFN,
mod
)
@raises_parse_error
def test_duplicate_adt_defn():
parse_text(
"""
%s
type List[A] {
Cons(A, List[A]),
Nil,
}
""" % LIST_DEFN
)
@raises_parse_error
def test_duplicate_adt_cons():
parse_text(
"""
type Ayy { Lmao }
type Haha { Lmao }
"""
)
@raises_parse_error
def test_duplicate_adt_cons_defn():
parse_text(
"""
type Ayy { Lmao }
type Lmao { Ayy }
"""
)
@raises_parse_error
def test_duplicate_global_var():
parse_text(
"""
def @id[A](%x: A) -> A { x }
def @id[A](%x: A) -> A { x }
"""
)
def test_extern_adt_defn():
# TODO(weberlo): update this test once extern is implemented
mod = relay.Module()
extern_var = relay.GlobalTypeVar("T")
typ_var = relay.TypeVar("A")
extern_def = relay.TypeData(extern_var, [typ_var], [])
mod[extern_var] = extern_def
assert parses_as(
"""
extern type T[A]
""",
mod
)
if __name__ == "__main__": if __name__ == "__main__":
test_comments() test_comments()
test_int_literal() test_int_literal()
...@@ -655,3 +894,14 @@ if __name__ == "__main__": ...@@ -655,3 +894,14 @@ if __name__ == "__main__":
test_tensor_type() test_tensor_type()
test_function_type() test_function_type()
test_tuple_type() test_tuple_type()
test_adt_defn()
test_empty_adt_defn()
test_multiple_cons_defn()
test_multiple_type_param_defn()
test_match()
test_adt_cons_expr()
test_duplicate_adt_defn()
test_duplicate_adt_cons()
test_duplicate_adt_cons_defn()
test_duplicate_global_var()
test_extern_adt_defn()
...@@ -23,14 +23,14 @@ from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equ ...@@ -23,14 +23,14 @@ from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equ
do_print = [False] do_print = [False]
SEMVER = "v0.0.3\n" SEMVER = "v0.0.4\n"
def astext(p, graph_equal=False): def astext(p, unify_free_vars=False):
txt = p.astext() txt = p.astext()
if isinstance(p, Expr) and free_vars(p): if isinstance(p, Expr) and free_vars(p):
return txt return txt
x = relay.fromtext(txt) x = relay.fromtext(txt)
if graph_equal: if unify_free_vars:
assert_graph_equal(x, p) assert_graph_equal(x, p)
else: else:
assert_alpha_equal(x, p) assert_alpha_equal(x, p)
...@@ -78,7 +78,7 @@ def test_meta_data(): ...@@ -78,7 +78,7 @@ def test_meta_data():
padding=(1, 1), padding=(1, 1),
channels=2) channels=2)
f = relay.Function([x, w], z) f = relay.Function([x, w], z)
text = astext(f, graph_equal=True) text = astext(f, unify_free_vars=True)
text_no_meta = str(f) text_no_meta = str(f)
assert "channels=2" in text assert "channels=2" in text
assert "channels=2" in text_no_meta assert "channels=2" in text_no_meta
...@@ -122,7 +122,7 @@ def test_let_if_scope(): ...@@ -122,7 +122,7 @@ def test_let_if_scope():
f = relay.Function([x, y, cond], result) f = relay.Function([x, y, cond], result)
text = astext(f) text = astext(f)
assert text.count("{") == 4 assert text.count("{") == 3
assert "%cond: bool" in text assert "%cond: bool" in text
show(astext(f)) show(astext(f))
...@@ -218,14 +218,6 @@ def test_zeros(): ...@@ -218,14 +218,6 @@ def test_zeros():
x = relay.op.zeros([], "float32") x = relay.op.zeros([], "float32")
astext(x) astext(x)
def test_cast():
data = relay.var('data', dtype='float32')
fp16_cast = relay.cast(data, dtype='float16')
cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast)
astext(cast_func)
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
test_lstm() test_lstm()
...@@ -247,4 +239,3 @@ if __name__ == "__main__": ...@@ -247,4 +239,3 @@ if __name__ == "__main__":
test_let_if_scope() test_let_if_scope()
test_variable_name() test_variable_name()
test_call_node_order() test_call_node_order()
test_cast()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment