Commit ca0292d8 by Logan Weber Committed by Jared Roesch

[Relay] Add ADTs to text format (#3863)

* Getting closer to having ADT defs

* ADT defs working probly

* Match parsing basipally done

* came to earth in a silver chrome UFO

* match finished?

* All tests but newest are passing

* ADT constructors work

now cleanup?

* Cleanup round 1

* Cleanup round 2

* Cleanup round 3

* Cleanup round 4

* Cleanup round 6

* Cleanup round 7

* Lil grammar fix

* Remove ANTLR Java files

* Lint roller

* Lint roller

* Address feedback

* Test completeness in match test

* Remove unused imports

* Lint roller

* Switch to Rust-style ADT syntax

* Lil fix

* Add dummy `extern type` handler

* Add type arg to test

* Update prelude semantic version

* Repair test

* Fix graph var handling in match

* Revert 's/graph_equal/is_unifiable' change
parent a103c4ee
......@@ -165,6 +165,13 @@ class ModuleNode : public RelayNode {
TVM_DLL TypeData LookupDef(const std::string& var) const;
/*!
* \brief Check if a global type definition exists
* \param var The name of the global type definition.
* \return Whether the definition exists.
*/
TVM_DLL bool HasDef(const std::string& var) const;
/*!
* \brief Look up a constructor by its tag.
* \param tag The tag for the constructor.
* \return The constructor object.
......
......@@ -17,11 +17,15 @@
* under the License.
*/
// list = *, seq = ?
/*
* NOTE: The `USE_ANTLR` option in `config.cmake` must be enabled in order for
* changes in this file to be reflected by the parser.
* NOTE: All upper-case rules are *lexer* rules and all camel-case rules are *parser* rules.
*/
grammar Relay;
SEMVER: 'v0.0.3' ;
SEMVER: 'v0.0.4' ;
// Lexing
// comments
......@@ -49,13 +53,8 @@ BOOL_LIT
| 'False'
;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;
DATATYPE : 'int64';
// non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
......@@ -74,7 +73,11 @@ METADATA: 'METADATA:' .*;
// A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) METADATA? EOF ;
// option: 'set' ident BOOL_LIT ;
// Covers both operator and type idents
generalIdent: CNAME ('.' CNAME)*;
globalVar: '@' CNAME ;
localVar: '%' ('_' | CNAME) ;
graphVar: '%' NAT ;
exprList: (expr (',' expr)*)?;
callList
......@@ -85,7 +88,6 @@ callList
expr
// operators
: '(' expr ')' # paren
| '{' expr '}' # paren
// function application
| expr '(' callList ')' # call
| '-' expr # neg
......@@ -99,53 +101,74 @@ expr
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
| matchType '(' expr ')' '{' matchClauseList? '}' # match
| expr '.' NAT # projection
// sequencing
| 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr
| expr ';;' expr # let
| GRAPH_VAR '=' expr ';' expr # graph
| graphVar '=' expr ';' expr # graph
| ident # identExpr
| scalar # scalarExpr
| meta # metaExpr
| QUOTED_STRING # stringExpr
;
func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ;
func: 'fn' typeParamList? '(' argList ')' ('->' typeExpr)? body ;
defn
: 'def' globalVar typeParamList? '(' argList ')' ('->' typeExpr)? body # funcDefn
| 'extern' 'type' generalIdent typeParamList? # externAdtDefn
| 'type' generalIdent typeParamList? '{' adtConsDefnList? '}' # adtDefn
;
constructorName: CNAME ;
adtConsDefnList: adtConsDefn (',' adtConsDefn)* ','? ;
adtConsDefn: constructorName ('(' typeExpr (',' typeExpr)* ')')? ;
matchClauseList: matchClause (',' matchClause)* ','? ;
matchClause: constructorName patternList? '=>' ('{' expr '}' | expr) ;
// complete or incomplete match, respectively
matchType : 'match' | 'match?' ;
patternList: '(' pattern (',' pattern)* ')';
pattern
: '_'
| localVar (':' typeExpr)?
;
adtCons: constructorName adtConsParamList? ;
adtConsParamList: '(' adtConsParam (',' adtConsParam)* ')' ;
adtConsParam: localVar | constructorName ;
argList
: varList # argNoAttr
| (var ',')* attrSeq # argWithAttr
;
varList: (var (',' var)*)?;
var: LOCAL_VAR (':' type_)?;
varList: (var (',' var)*)? ;
var: localVar (':' typeExpr)? ;
attrSeq: attr (',' attr)*;
attrSeq: attr (',' attr)* ;
attr: CNAME '=' expr ;
typeParamList
: '[' ']'
| '[' ident (',' ident)* ']'
;
type_
typeExpr
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType
| 'Tensor' '[' shapeList ',' type_ ']' # tensorType
| 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '(' typeExpr ',' ')' # tupleType
| '(' typeExpr (',' typeExpr)+ ')' # tupleType
| generalIdent typeParamList # typeCallType
| generalIdent # typeIdentType
| 'Tensor' '[' shapeList ',' typeExpr ']' # tensorType
| 'fn' typeParamList? '(' (typeExpr (',' typeExpr)*)? ')' '->' typeExpr # funcType
| '_' # incompleteType
| NAT # intType
;
typeParamList: '[' generalIdent (',' generalIdent)* ']' ;
shapeList
: '(' shape (',' shape)+ ')'
| '(' ')'
: '(' ')'
| '(' shape (',' shape)+ ')'
| shape
;
......@@ -157,12 +180,6 @@ shape
| NAT # intShape
;
typeIdent : CNAME;
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
// bool
body: '{' expr '}' ;
scalar
......@@ -172,8 +189,8 @@ scalar
;
ident
: opIdent
| GLOBAL_VAR
| LOCAL_VAR
| GRAPH_VAR
: generalIdent
| globalVar
| localVar
| graphVar
;
This source diff could not be displayed because it is too large. You can view the blob instead.
# Generated from /workspace/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.1
# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2
from antlr4 import *
if __name__ is not None and "." in __name__:
from .RelayParser import RelayParser
......@@ -9,13 +9,28 @@ else:
class RelayVisitor(ParseTreeVisitor):
# Visit a parse tree produced by RelayParser#opIdent.
def visitOpIdent(self, ctx:RelayParser.OpIdentContext):
# Visit a parse tree produced by RelayParser#prog.
def visitProg(self, ctx:RelayParser.ProgContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#prog.
def visitProg(self, ctx:RelayParser.ProgContext):
# Visit a parse tree produced by RelayParser#generalIdent.
def visitGeneralIdent(self, ctx:RelayParser.GeneralIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#globalVar.
def visitGlobalVar(self, ctx:RelayParser.GlobalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#localVar.
def visitLocalVar(self, ctx:RelayParser.LocalVarContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graphVar.
def visitGraphVar(self, ctx:RelayParser.GraphVarContext):
return self.visitChildren(ctx)
......@@ -44,6 +59,11 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#match.
def visitMatch(self, ctx:RelayParser.MatchContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext):
return self.visitChildren(ctx)
......@@ -114,8 +134,73 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#defn.
def visitDefn(self, ctx:RelayParser.DefnContext):
# Visit a parse tree produced by RelayParser#funcDefn.
def visitFuncDefn(self, ctx:RelayParser.FuncDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#externAdtDefn.
def visitExternAdtDefn(self, ctx:RelayParser.ExternAdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtDefn.
def visitAdtDefn(self, ctx:RelayParser.AdtDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#constructorName.
def visitConstructorName(self, ctx:RelayParser.ConstructorNameContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefnList.
def visitAdtConsDefnList(self, ctx:RelayParser.AdtConsDefnListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsDefn.
def visitAdtConsDefn(self, ctx:RelayParser.AdtConsDefnContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClauseList.
def visitMatchClauseList(self, ctx:RelayParser.MatchClauseListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchClause.
def visitMatchClause(self, ctx:RelayParser.MatchClauseContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#matchType.
def visitMatchType(self, ctx:RelayParser.MatchTypeContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#patternList.
def visitPatternList(self, ctx:RelayParser.PatternListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#pattern.
def visitPattern(self, ctx:RelayParser.PatternContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtCons.
def visitAdtCons(self, ctx:RelayParser.AdtConsContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParamList.
def visitAdtConsParamList(self, ctx:RelayParser.AdtConsParamListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#adtConsParam.
def visitAdtConsParam(self, ctx:RelayParser.AdtConsParamContext):
return self.visitChildren(ctx)
......@@ -149,13 +234,13 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParamList.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
# Visit a parse tree produced by RelayParser#tupleType.
def visitTupleType(self, ctx:RelayParser.TupleTypeContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tupleType.
def visitTupleType(self, ctx:RelayParser.TupleTypeContext):
# Visit a parse tree produced by RelayParser#typeCallType.
def visitTypeCallType(self, ctx:RelayParser.TypeCallTypeContext):
return self.visitChildren(ctx)
......@@ -179,8 +264,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#intType.
def visitIntType(self, ctx:RelayParser.IntTypeContext):
# Visit a parse tree produced by RelayParser#typeParamList.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
return self.visitChildren(ctx)
......@@ -209,11 +294,6 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeIdent.
def visitTypeIdent(self, ctx:RelayParser.TypeIdentContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#body.
def visitBody(self, ctx:RelayParser.BodyContext):
return self.visitChildren(ctx)
......
......@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
v0.0.3
v0.0.4
def @id[a](%x: a) -> a {
%x
......
......@@ -70,7 +70,10 @@ class AlphaEqualHandler:
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) return false;
if (!rhsm->HasDef(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
return false;
}
}
return true;
}
......@@ -288,7 +291,7 @@ class AlphaEqualHandler:
}
bool VisitType_(const GlobalTypeVarNode* lhs, const Type& other) final {
return GetRef<Type>(lhs) == other;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
}
bool VisitType_(const TypeCallNode* lhs, const Type& other) final {
......@@ -307,6 +310,26 @@ class AlphaEqualHandler:
return true;
}
bool VisitType_(const TypeDataNode* lhs, const Type& other) final {
const TypeDataNode* rhs = other.as<TypeDataNode>();
if (rhs == nullptr
|| lhs->type_vars.size() != rhs->type_vars.size()
|| !TypeEqual(lhs->header, rhs->header)) {
return false;
}
for (size_t i = 0; i < lhs->type_vars.size(); ++i) {
if (!TypeEqual(lhs->type_vars[i], rhs->type_vars[i])) {
return false;
}
}
for (size_t i = 0; i < lhs->constructors.size(); ++i) {
if (!ExprEqual(lhs->constructors[i], rhs->constructors[i])) {
return false;
}
}
return true;
}
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
......@@ -485,7 +508,10 @@ class AlphaEqualHandler:
}
bool VisitExpr_(const ConstructorNode* lhs, const Expr& other) final {
return GetRef<Expr>(lhs) == other;
if (const ConstructorNode* rhs = other.as<ConstructorNode>()) {
return lhs->name_hint == rhs->name_hint;
}
return false;
}
bool ClauseEqual(const Clause& lhs, const Clause& rhs) {
......@@ -582,7 +608,7 @@ TVM_REGISTER_API("relay._make._alpha_equal")
TVM_REGISTER_API("relay._make._assert_alpha_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not alpha equal";
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});
TVM_REGISTER_API("relay._make._graph_equal")
......@@ -593,7 +619,7 @@ TVM_REGISTER_API("relay._make._graph_equal")
TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not graph equal";
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
});
} // namespace relay
......
......@@ -206,6 +206,11 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id);
}
bool ModuleNode::HasDef(const std::string& name) const {
auto it = global_type_var_map_.find(name);
return it != global_type_var_map_.end();
}
Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end())
......
......@@ -44,6 +44,8 @@
namespace tvm {
namespace relay {
static const char* kSemVer = "v0.0.4";
Doc Brace(const Doc& d,
const std::string& open = "{",
const std::string& close = "}",
......@@ -239,6 +241,8 @@ class PrettyPrinter :
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<PatternNode>()) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as_derived<ModuleNode>()) {
return PrintMod(Downcast<Module>(node));
} else {
......@@ -313,7 +317,7 @@ class PrettyPrinter :
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Doc val = GetUniqueName("%" + name);
Doc val = GetUniqueName(name);
memo_type_[var] = val;
if (var->kind != kType) {
val << ": " << Print(var->kind);
......@@ -347,13 +351,17 @@ class PrettyPrinter :
}
bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head &&
dg_.expr_node.at(expr)->parents.head->next);
auto it = dg_.expr_node.find(expr);
if (it == dg_.expr_node.end()) {
return true;
} else {
return !(it->second->parents.head && it->second->parents.head->next);
}
}
bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
expr.as<OpNode>() || expr.as<VarNode>();
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
expr.as<VarNode>() || expr.as<ConstructorNode>();
}
//------------------------------------
......@@ -380,9 +388,9 @@ class PrettyPrinter :
} else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
printed_expr << "(";
printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << "}";
printed_expr << ")";
} else {
printed_expr = VisitExpr(expr);
}
......@@ -483,13 +491,13 @@ class PrettyPrinter :
Doc doc;
doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
doc << "[";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
type_params.push_back(Doc(tv->var->name_hint));
}
doc << PrintSep(type_params);
doc << ">";
doc << "]";
}
doc << "(";
std::vector<Doc> params;
......@@ -510,6 +518,15 @@ class PrettyPrinter :
Doc PrintMod(const Module& mod) {
Doc doc;
int counter = 0;
// type definitions
for (const auto& kv : mod->type_definitions) {
if (counter++ != 0) {
doc << PrintNewLine();
}
doc << Print(kv.second);
doc << PrintNewLine();
}
// functions
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
......@@ -547,7 +564,12 @@ class PrettyPrinter :
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
const auto* cons_node = op->op.as<ConstructorNode>();
if (cons_node) {
doc << cons_node->name_hint;
} else {
doc << Print(op->op);
}
return doc << "(" << PrintSep(args) << ")";
}
......@@ -570,27 +592,57 @@ class PrettyPrinter :
// TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
Doc doc;
Doc body;
doc << "match " << Print(op->data) << " ";
doc << "{";
std::vector<Doc> clauses;
doc << "match";
if (!op->complete) {
doc << "?";
}
doc << " (" << Print(op->data) << ") {";
std::vector<Doc> clause_docs;
for (const auto& clause : op->clauses) {
Doc clause_doc;
clauses.push_back(clause_doc << Print(clause->lhs) << " -> "
<< Print(clause->rhs));
clause_doc << PrintPattern(clause->lhs, false) << " => ";
Doc rhs_doc = PrintScope(clause->rhs);
if (clause->rhs.as<LetNode>()) {
// only add braces if there are multiple lines on the rhs
rhs_doc = Brace(rhs_doc);
}
doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine();
doc << "}";
clause_doc << rhs_doc << ",";
clause_docs.push_back(clause_doc);
}
doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine()))
<< PrintNewLine() << "}";
return doc;
}
Doc PrintPattern(const Pattern& pattern, bool meta) {
auto it = memo_pattern_.find(pattern);
if (it != memo_pattern_.end()) return it->second;
Doc printed_pattern;
if (meta) {
printed_pattern = meta_.GetMetaNode(GetRef<NodeRef>(pattern.get()));
} else {
printed_pattern = VisitPattern(pattern);
}
memo_pattern_[pattern] = printed_pattern;
return printed_pattern;
}
Doc VisitPattern_(const PatternConstructorNode* p) final {
Doc doc;
doc << p->constructor->name_hint << "(";
doc << p->constructor->name_hint;
if (!p->patterns.empty()) {
doc << "(";
std::vector<Doc> pats;
for (const auto& pat : p->patterns) {
pats.push_back(Print(pat));
}
return doc << PrintSep(pats) << ")";
doc << PrintSep(pats) << ")";
}
return doc;
}
Doc VisitPattern_(const PatternWildcardNode* pw) final {
return Doc("_");
}
Doc VisitPattern_(const PatternVarNode* pv) final {
......@@ -598,7 +650,17 @@ class PrettyPrinter :
}
Doc VisitExpr_(const ConstructorNode* n) final {
return Doc(n->name_hint);
Doc doc;
doc << n->name_hint;
if (n->inputs.size() != 0) {
doc << "(";
std::vector<Doc> inputs;
for (Type input : n->inputs) {
inputs.push_back(Print(input));
}
doc << PrintSep(inputs) << ")";
}
return doc;
}
//------------------------------------
......@@ -623,7 +685,7 @@ class PrettyPrinter :
}
Doc VisitType_(const TypeVarNode* node) final {
return AllocTypeVar(GetRef<TypeVar>(node));
return Doc(node->var->name_hint);
}
Doc VisitType_(const GlobalTypeVarNode* node) final {
......@@ -675,13 +737,13 @@ class PrettyPrinter :
Doc doc;
doc << "fn ";
if (node->type_params.size() != 0) {
doc << "<";
doc << "[";
std::vector<Doc> type_params;
for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param));
}
doc << PrintSep(type_params);
doc << ">";
doc << "]";
}
std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) {
......@@ -695,6 +757,37 @@ class PrettyPrinter :
return doc << "ref(" << Print(node->value) << ")";
}
Doc VisitType_(const TypeDataNode* node) final {
Doc doc;
doc << "type " << Print(node->header);
// type vars
if (node->type_vars.size() != 0) {
doc << "[";
std::vector<Doc> type_vars;
for (Type type_var : node->type_vars) {
type_vars.push_back(Print(type_var));
}
doc << PrintSep(type_vars) << "]";
}
doc << " ";
std::vector<Doc> constructor_docs;
for (Constructor constructor : node->constructors) {
constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
}
Doc separator;
separator << "," << PrintNewLine();
Doc adt_body;
adt_body << PrintSep(constructor_docs, separator);
// add trailing comma if there are any constructors
if (!constructor_docs.empty()) {
adt_body << ",";
}
doc << Brace(adt_body);
return doc;
}
//------------------------------------
// Overload of Attr printing functions
//------------------------------------
......@@ -758,6 +851,8 @@ class PrettyPrinter :
std::unordered_map<Expr, Doc, NodeHash, NodeEqual> memo_;
/*! \brief Map from Type to Doc */
std::unordered_map<Type, Doc, NodeHash, NodeEqual> memo_type_;
/*! \brief Map from Type to Doc */
std::unordered_map<Pattern, Doc, NodeHash, NodeEqual> memo_pattern_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief meta data context */
......@@ -861,7 +956,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.3" << PrintNewLine()
doc << kSemVer << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
......
......@@ -774,7 +774,6 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
bool update_missing_type_annotation_{true};
};
Expr TypeInferencer::Infer(Expr expr) {
// Step 1: Populate the constraints.
GetType(expr);
......
......@@ -16,14 +16,14 @@
# under the License.
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.analysis import graph_equal, assert_graph_equal
from nose.tools import nottest, raises
from numpy import isclose
from typing import Union
from functools import wraps
raises_parse_error = raises(tvm._ffi.base.TVMError)
SEMVER = "v0.0.3"
SEMVER = "v0.0.4"
BINARY_OPS = {
"*": relay.multiply,
......@@ -60,20 +60,29 @@ TYPES = {
"float16x4",
}
LIST_DEFN = """
type List[A] {
Cons(A, List[A]),
Nil,
}
"""
def roundtrip(expr):
x = relay.fromtext(str(expr))
assert_alpha_equal(x, expr)
assert_graph_equal(x, expr)
def parse_text(code):
x = relay.fromtext(SEMVER + "\n" + code)
roundtrip(x)
return x
expr = relay.fromtext(SEMVER + "\n" + code)
roundtrip(expr)
return expr
def parses_as(code, expr):
# type: (str, relay.Expr) -> bool
return alpha_equal(parse_text(code), expr)
parsed = parse_text(code)
result = graph_equal(parsed, expr)
return result
def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool])
......@@ -168,13 +177,13 @@ def test_bin_op():
def test_parens():
assert alpha_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not alpha_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
assert graph_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
assert not graph_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))
def test_op_assoc():
assert alpha_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert alpha_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
assert graph_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
assert graph_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))
@nottest
......@@ -239,7 +248,7 @@ def test_seq():
)
assert parses_as(
"let %_ = { 1 }; ()",
"let %_ = 1; ()",
relay.Let(
X,
relay.const(1),
......@@ -249,13 +258,13 @@ def test_seq():
def test_graph():
code = "%0 = (); %1 = 1; (%0, %0, %1)"
assert parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)",
code,
relay.Tuple([UNIT, UNIT, relay.const(1)])
)
assert not parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)",
code,
relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
)
......@@ -632,6 +641,236 @@ def test_tuple_type():
)
)
def test_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(
glob_typ_var,
[],
[relay.Constructor("Nil", [], glob_typ_var)])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { Nil }
""",
mod
)
def test_empty_adt_defn():
mod = relay.Module()
glob_typ_var = relay.GlobalTypeVar("Ayy")
prog = relay.TypeData(glob_typ_var, [], [])
mod[glob_typ_var] = prog
assert parses_as(
"""
type Ayy { }
""",
mod
)
def test_multiple_cons_defn():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
prog = relay.TypeData(
list_var,
[typ_var],
[
relay.Constructor("Cons", [typ_var, list_var(typ_var)], list_var),
relay.Constructor("Nil", [], list_var),
])
mod[list_var] = prog
assert parses_as(LIST_DEFN, mod)
def test_multiple_type_param_defn():
glob_typ_var = relay.GlobalTypeVar("Either")
typ_var_a = relay.TypeVar("A")
typ_var_b = relay.TypeVar("B")
prog = relay.TypeData(
glob_typ_var,
[typ_var_a, typ_var_b],
[
relay.Constructor("Left", [typ_var_a], glob_typ_var),
relay.Constructor("Right", [typ_var_b], glob_typ_var),
])
mod = relay.Module()
mod[glob_typ_var] = prog
assert parses_as(
"""
type Either[A, B] {
Left(A),
Right(B),
}
""",
mod
)
def test_match():
# pair each match keyword with whether it specifies a complete match or not
match_keywords = [("match", True), ("match?", False)]
for (match_keyword, is_complete) in match_keywords:
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
length_var = relay.GlobalVar("length")
typ_var = relay.TypeVar("A")
input_type = list_var(typ_var)
input_var = relay.Var("xs", input_type)
rest_var = relay.Var("rest")
cons_case = relay.Let(
_,
UNIT,
relay.add(relay.const(1), relay.Call(length_var, [rest_var])))
body = relay.Match(input_var,
[relay.Clause(
relay.PatternConstructor(
cons_constructor,
[relay.PatternWildcard(), relay.PatternVar(rest_var)]),
cons_case),
relay.Clause(
relay.PatternConstructor(nil_constructor, []),
relay.const(0))],
complete=is_complete
)
length_func = relay.Function(
[input_var],
body,
int32,
[typ_var]
)
mod[length_var] = length_func
assert parses_as(
"""
%s
def @length[A](%%xs: List[A]) -> int32 {
%s (%%xs) {
Cons(_, %%rest) => {
();;
1 + @length(%%rest)
},
Nil => 0,
}
}
""" % (LIST_DEFN, match_keyword),
mod
)
def test_adt_cons_expr():
mod = relay.Module()
list_var = relay.GlobalTypeVar("List")
typ_var = relay.TypeVar("A")
cons_constructor = relay.Constructor(
"Cons", [typ_var, list_var(typ_var)], list_var)
nil_constructor = relay.Constructor("Nil", [], list_var)
list_def = relay.TypeData(
list_var,
[typ_var],
[cons_constructor, nil_constructor])
mod[list_var] = list_def
make_singleton_var = relay.GlobalVar("make_singleton")
input_var = relay.Var("x", int32)
make_singleton_func = relay.Function(
[input_var],
cons_constructor(input_var, nil_constructor()),
list_var(int32)
)
mod[make_singleton_var] = make_singleton_func
assert parses_as(
"""
%s
def @make_singleton(%%x: int32) -> List[int32] {
Cons(%%x, Nil())
}
""" % LIST_DEFN,
mod
)
@raises_parse_error
def test_duplicate_adt_defn():
parse_text(
"""
%s
type List[A] {
Cons(A, List[A]),
Nil,
}
""" % LIST_DEFN
)
@raises_parse_error
def test_duplicate_adt_cons():
parse_text(
"""
type Ayy { Lmao }
type Haha { Lmao }
"""
)
@raises_parse_error
def test_duplicate_adt_cons_defn():
parse_text(
"""
type Ayy { Lmao }
type Lmao { Ayy }
"""
)
@raises_parse_error
def test_duplicate_global_var():
parse_text(
"""
def @id[A](%x: A) -> A { x }
def @id[A](%x: A) -> A { x }
"""
)
def test_extern_adt_defn():
# TODO(weberlo): update this test once extern is implemented
mod = relay.Module()
extern_var = relay.GlobalTypeVar("T")
typ_var = relay.TypeVar("A")
extern_def = relay.TypeData(extern_var, [typ_var], [])
mod[extern_var] = extern_def
assert parses_as(
"""
extern type T[A]
""",
mod
)
if __name__ == "__main__":
test_comments()
test_int_literal()
......@@ -655,3 +894,14 @@ if __name__ == "__main__":
test_tensor_type()
test_function_type()
test_tuple_type()
test_adt_defn()
test_empty_adt_defn()
test_multiple_cons_defn()
test_multiple_type_param_defn()
test_match()
test_adt_cons_expr()
test_duplicate_adt_defn()
test_duplicate_adt_cons()
test_duplicate_adt_cons_defn()
test_duplicate_global_var()
test_extern_adt_defn()
......@@ -23,14 +23,14 @@ from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equ
do_print = [False]
SEMVER = "v0.0.3\n"
SEMVER = "v0.0.4\n"
def astext(p, graph_equal=False):
def astext(p, unify_free_vars=False):
txt = p.astext()
if isinstance(p, Expr) and free_vars(p):
return txt
x = relay.fromtext(txt)
if graph_equal:
if unify_free_vars:
assert_graph_equal(x, p)
else:
assert_alpha_equal(x, p)
......@@ -78,7 +78,7 @@ def test_meta_data():
padding=(1, 1),
channels=2)
f = relay.Function([x, w], z)
text = astext(f, graph_equal=True)
text = astext(f, unify_free_vars=True)
text_no_meta = str(f)
assert "channels=2" in text
assert "channels=2" in text_no_meta
......@@ -122,7 +122,7 @@ def test_let_if_scope():
f = relay.Function([x, y, cond], result)
text = astext(f)
assert text.count("{") == 4
assert text.count("{") == 3
assert "%cond: bool" in text
show(astext(f))
......@@ -218,14 +218,6 @@ def test_zeros():
x = relay.op.zeros([], "float32")
astext(x)
def test_cast():
data = relay.var('data', dtype='float32')
fp16_cast = relay.cast(data, dtype='float16')
cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast)
astext(cast_func)
if __name__ == "__main__":
do_print[0] = True
test_lstm()
......@@ -247,4 +239,3 @@ if __name__ == "__main__":
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_cast()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment