Commit 2ae3124f by Steven S. Lyubomirsky Committed by ziheng

[Relay] Algebraic data types (#2442)

* First pass on ADTs

* Add doc string for tag field

* Visit constructors in TypeVisitor for TypeData

* Add to description of type call

* Add type call to type solving and unification

* Make type mutator for typecall consistent with others (only create new node if there's a change)

* Ensure kindchecking can handle type calls and typedata

* Fix bad nesting in module constructor

* Correctly construct call in typecall test

* Add call override for ordinary vars (do we want this?)

* Remove generalization hack from type inference because it was breaking ADT constructors

* Check that there are no free type vars in exprs after inferring type

* Free var checks need module because of ADT constructors

* Typecall test can't have unbound type var, make it global

* Uncomment tmap test and remove comments about failing to infer ret type; those work now

* Put in dummy visits for ADTs in graph runtime codegen to placate pylint

* Fix Relay type infer test module constructor

* Mark override for TypeCallNode in type solver

* Ensure free vars check treats patern vars as bound

* Run interpreter in more ADT test cases

* Refactor kind check to return the kind, like typechecking

* Fix invalid typecall in test

* Add kind check to type inference, do not use nulls in func_type_annotation()!

* Redundant whitespace

* Make TypeData a separate kind

* Make ADT handles a separate kind too, document calling convention better

* Remove nats and tree from prelude, move to test, document prelude

* Restore and document nat and tree to prelude, add more tree tests

* Add alpha equality tests for match cases, fix variable binding bug

* Add more kind check tests for ADTs

* Add more tests for finding free or bound vars in match exprs

* Add unification tests for type call

* Update main() for alpha equality tests

* Add simple type inference test cases for match exprs and ADT constructors

* Add more ADT interpreter tests

* Allow incomplete types when typechecking match cases

* Type inference for pattern vars should use the type annotation if it's there

* Two more specific test cases for ADT matching

* Add option ADT to prelude

* Fix broken reference to kind enum

* Fix rebase snags

* Do not attach checked types to constructors

* More docstrings for module fields

* Use proper wrapper for indexing into module type data

* checked_type for constructors is not populated

* Expand type call docstring

* Rename PatternConstructor con field

* Use error reporter for pattern constructor case

* Condense error reporting in kind check, use error reporter

* Expand docstrings and rename ADT fields

* Rename 'option' ADT to 'optional' for consistency with Python

* Add various list iterators and utility functions to prelude

* Add smoke tests for new iterators in prelude

* Add concat to prelude

* Add smoke test for concat

* Correct docstrings in prelude

* Ensure that type defs are written in module initialization

* Various requested renamings

* Correct rebase snags

* Add kind check tests for ref types

* Update the main() for kind checking tests
parent c7165427
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/adt.h
* \brief Algebraic data types for Relay
*/
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_
#include <tvm/attrs.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"
namespace tvm {
namespace relay {
/*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
};
/*!
* \brief Pattern is the base type for an ADT match pattern in Relay.
*
* Given an ADT value, a pattern might accept it and bind the pattern variable to some value
* (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
*
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
*/
class Pattern : public NodeRef {
public:
Pattern() {}
explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {}
using ContainerType = PatternNode;
};
/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
PatternWildcardNode() {}
TVM_DLL static PatternWildcard make();
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);
/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
public:
PatternVarNode() {}
/*! \brief Variable that stores the matched value. */
tvm::relay::Var var;
TVM_DLL static PatternVar make(tvm::relay::Var var);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);
/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
*/
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
tvm::Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int tag = -1;
ConstructorNode() {}
TVM_DLL static Constructor make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternVar container node */
class PatternConstructorNode : public PatternNode {
public:
/*! Constructor matched by the pattern. */
Constructor constructor;
/*! Sub-patterns to match against each input to the constructor. */
tvm::Array<Pattern> patterns;
PatternConstructorNode() {}
TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData;
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
tvm::Array<TypeVar> type_vars;
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}
TVM_DLL static TypeData make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors);
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);
/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
class ClauseNode : public Node {
public:
/*! \brief The pattern the clause matches. */
Pattern lhs;
/*! \brief The resulting value. */
Expr rhs;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
}
TVM_DLL static Clause make(Pattern lhs, Expr rhs);
static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
};
RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);
/*! \brief ADT pattern matching exression. */
class Match;
/*! \brief Match container node. */
class MatchNode : public ExprNode {
public:
/*! \brief The input being deconstructed. */
Expr data;
/*! \brief The match node clauses. */
tvm::Array<Clause> clauses;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("clause", &clauses);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern);
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ADT_H_
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <string> #include <string>
#include "./expr.h" #include "./expr.h"
#include "./adt.h"
#include "./op.h" #include "./op.h"
#include "./error.h" #include "./error.h"
...@@ -92,6 +93,8 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -92,6 +93,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) { virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key()); throw Error(std::string("Do not have a default for ") + op->type_key());
} }
...@@ -114,6 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -114,6 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
return vtable; return vtable;
} }
}; };
...@@ -142,7 +147,11 @@ class ExprVisitor ...@@ -142,7 +147,11 @@ class ExprVisitor
void VisitExpr_(const RefCreateNode* op) override; void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override; void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override; void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
virtual void VisitType(const Type& t); virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);
protected: protected:
// Internal visiting counter // Internal visiting counter
...@@ -180,6 +189,9 @@ class ExprMutator ...@@ -180,6 +189,9 @@ class ExprMutator
Expr VisitExpr_(const RefCreateNode* op) override; Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override; Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override; Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;
/*! /*!
* \brief Used to visit the types inside of expressions. * \brief Used to visit the types inside of expressions.
* *
...@@ -188,6 +200,8 @@ class ExprMutator ...@@ -188,6 +200,8 @@ class ExprMutator
* visitor for types which transform them appropriately. * visitor for types which transform them appropriately.
*/ */
virtual Type VisitType(const Type& t); virtual Type VisitType(const Type& t);
virtual Clause VisitClause(const Clause& c);
virtual Pattern VisitPattern(const Pattern& c);
protected: protected:
/*! \brief Internal map used for memoization. */ /*! \brief Internal map used for memoization. */
......
...@@ -160,6 +160,28 @@ struct RefValueNode : ValueNode { ...@@ -160,6 +160,28 @@ struct RefValueNode : ValueNode {
RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
/*! \brief An ADT constructor value. */
class ConstructorValue;
struct ConstructorValueNode : ValueNode {
Constructor constructor;
tvm::Array<Value> fields;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("fields", &fields);
}
TVM_DLL static ConstructorValue make(Constructor constructor,
tvm::Array<Value> fields);
static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_INTERPRETER_H_ #endif // TVM_RELAY_INTERPRETER_H_
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <tvm/relay/error.h> #include <tvm/relay/error.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <string> #include <string>
...@@ -35,13 +36,15 @@ struct Module; ...@@ -35,13 +36,15 @@ struct Module;
* *
* The functional style allows users to construct custom * The functional style allows users to construct custom
* environments easily, for example each thread can store * environments easily, for example each thread can store
* an Module while auto-tuning. * a Module while auto-tuning.
* */ * */
class ModuleNode : public RelayNode { class ModuleNode : public RelayNode {
public: public:
/*! \brief A map from ids to all global functions. */ /*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions; tvm::Map<GlobalVar, Function> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
/*! \brief The entry function (i.e. "main"). */ /*! \brief The entry function (i.e. "main"). */
GlobalVar entry_func; GlobalVar entry_func;
...@@ -50,15 +53,18 @@ class ModuleNode : public RelayNode { ...@@ -50,15 +53,18 @@ class ModuleNode : public RelayNode {
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions); v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_); v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func); v->Visit("entry_func", &entry_func);
v->Visit("global_type_var_map_", &global_type_var_map_);
} }
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs); TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs);
/*! /*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
* \param var The name of the global function. * \param var The var of the global function.
* \param func The function. * \param func The function.
* \param update Controls whether you can replace a definition in the * \param update Controls whether you can replace a definition in the
* environment. * environment.
...@@ -66,6 +72,13 @@ class ModuleNode : public RelayNode { ...@@ -66,6 +72,13 @@ class ModuleNode : public RelayNode {
void Add(const GlobalVar& var, const Function& func, bool update = false); void Add(const GlobalVar& var, const Function& func, bool update = false);
/*! /*!
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The type definition.
*/
void AddDef(const GlobalTypeVar& var, const TypeData& type);
/*!
* \brief Add a function to the global environment. * \brief Add a function to the global environment.
* \param var The name of the global function. * \param var The name of the global function.
* \param func The function. * \param func The function.
...@@ -95,6 +108,13 @@ class ModuleNode : public RelayNode { ...@@ -95,6 +108,13 @@ class ModuleNode : public RelayNode {
GlobalVar GetGlobalVar(const std::string& str); GlobalVar GetGlobalVar(const std::string& str);
/*! /*!
* \brief Look up a global function by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalTypeVar GetGlobalTypeVar(const std::string& str);
/*!
* \brief Lookup a global function by its variable. * \brief Lookup a global function by its variable.
* \param var The global var to lookup. * \param var The global var to lookup.
* \returns The function named by the variable argument. * \returns The function named by the variable argument.
...@@ -109,6 +129,20 @@ class ModuleNode : public RelayNode { ...@@ -109,6 +129,20 @@ class ModuleNode : public RelayNode {
Function Lookup(const std::string& name); Function Lookup(const std::string& name);
/*! /*!
* \brief Lookup a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const GlobalTypeVar& var);
/*!
* \brief Lookup a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const std::string& var);
/*!
* \brief Update the functions inside this environment by * \brief Update the functions inside this environment by
* functions in another environment. * functions in another environment.
* \param other The other environment. * \param other The other environment.
...@@ -137,6 +171,11 @@ class ModuleNode : public RelayNode { ...@@ -137,6 +171,11 @@ class ModuleNode : public RelayNode {
* ensures global uniqueness. * ensures global uniqueness.
*/ */
tvm::Map<std::string, GlobalVar> global_var_map_; tvm::Map<std::string, GlobalVar> global_var_map_;
/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
}; };
struct Module : public NodeRef { struct Module : public NodeRef {
......
...@@ -422,7 +422,7 @@ inline OpRegistry& OpRegistry::add_type_rel( ...@@ -422,7 +422,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
std::string input_name_prefix = "in"; std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) { for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i); auto name = input_name_prefix + std::to_string(i);
auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType); auto param = TypeVarNode::make(name, Kind::kType);
type_params.push_back(param); type_params.push_back(param);
arg_types.push_back(param); arg_types.push_back(param);
} }
...@@ -430,7 +430,7 @@ inline OpRegistry& OpRegistry::add_type_rel( ...@@ -430,7 +430,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
Array<Type> ty_call_args = arg_types; Array<Type> ty_call_args = arg_types;
// Add output type. // Add output type.
auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType); auto out_param = TypeVarNode::make("out", Kind::kType);
type_params.push_back(out_param); type_params.push_back(out_param);
// this will trigger copy on write. // this will trigger copy on write.
ty_call_args.push_back(out_param); ty_call_args.push_back(out_param);
......
...@@ -56,9 +56,9 @@ TVM_DLL Function InferType(const Function& f, const Module& mod, ...@@ -56,9 +56,9 @@ TVM_DLL Function InferType(const Function& f, const Module& mod,
* \param t The type to check. * \param t The type to check.
* \param mod The global module. * \param mod The global module.
* *
* \return true if the rules are satisified otherwise false * \return The kind of the passed type.
*/ */
TVM_DLL bool KindCheck(const Type& t, const Module& mod); TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
/*! \brief Compare two expressions for structural equivalence. /*! \brief Compare two expressions for structural equivalence.
* *
...@@ -144,10 +144,11 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr); ...@@ -144,10 +144,11 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
* type in the context. * type in the context.
* *
* \param expr the expression. * \param expr the expression.
* \param mod the module.
* *
* \return List of free vars, in the PostDFS order visited by expr. * \return List of free vars, in the PostDFS order visited by expr.
*/ */
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get free TypeVars from type t. /*! \brief Get free TypeVars from type t.
* *
...@@ -155,10 +156,11 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr); ...@@ -155,10 +156,11 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
* type in the context. * type in the context.
* *
* \param t the type. * \param t the type.
* \param mod the module.
* *
* \return List of free type vars, in the PostDFS order visited by type. * \return List of free type vars, in the PostDFS order visited by type.
*/ */
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t); TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
/*! \brief Get all bound type variables from expression expr. /*! \brief Get all bound type variables from expression expr.
* *
...@@ -166,10 +168,11 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t); ...@@ -166,10 +168,11 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t);
* They only have meaning inside that expr, and can only be used in it. * They only have meaning inside that expr, and can only be used in it.
* *
* \param expr the expression. * \param expr the expression.
* \param mod the module.
* *
* \return List of bound type vars, in the PostDFS order in the expression. * \return List of bound type vars, in the PostDFS order in the expression.
*/ */
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr); TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get all bound type variables from type t. /*! \brief Get all bound type variables from type t.
* *
...@@ -177,26 +180,29 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr); ...@@ -177,26 +180,29 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);
* They only have meaning inside that type, and can only be used in it. * They only have meaning inside that type, and can only be used in it.
* *
* \param t the type * \param t the type
* \param mod the module.
* *
* \return List of bound type vars, in the PostDFS order visited by type. * \return List of bound type vars, in the PostDFS order visited by type.
*/ */
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t); TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
/*! \brief Get all type variables in expression expr. /*! \brief Get all type variables in expression expr.
* *
* \param expr the expression. * \param expr the expression.
* \param mod the module.
* *
* \return List of type vars, in the PostDFS order in the expression. * \return List of type vars, in the PostDFS order in the expression.
*/ */
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr); TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get all type variables in type t. /*! \brief Get all type variables in type t.
* *
* \param t the type. * \param t the type.
* \param mod the module.
* *
* \return List of type vars, in the PostDFS order visited by type. * \return List of type vars, in the PostDFS order visited by type.
*/ */
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t); TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*! \brief Remove expressions which does not effect the program result. /*! \brief Remove expressions which does not effect the program result.
* *
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/pattern_functor.h
* \brief A more powerful visitor on ADT patterns that enables defining
* arbitrary function signatures with type-based dispatch on first argument.
*/
#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
#define TVM_RELAY_PATTERN_FUNCTOR_H_
#include <tvm/node/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./op.h"
#include "./error.h"
#include "./adt.h"
namespace tvm {
namespace relay {
/*!
* \brief A dynamical functor on ADT patterns that dispatches on its first argument.
* You can use this as a more powerful visitor, since it allows you to
* define the types of further arguments to VisitPattern.
*
* \sa tvm/ir_functor.h
*
* \tparam FType function signiture
* This type is only defined for FType with function signature R(const Pattern&,
* Args...)
*/
template <typename FType>
class PatternFunctor;
// functions to be overriden.
#define PATTERN_FUNCTOR_DEFAULT \
{ return VisitPatternDefault_(op, std::forward<Args>(args)...); }
#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitPattern_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
});
template <typename R, typename... Args>
class PatternFunctor<R(const Pattern& n, Args...)> {
private:
using TSelf = PatternFunctor<R(const Pattern& n, Args...)>;
using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~PatternFunctor() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Pattern& n, Args... args) {
return VisitPattern(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitPattern(const Pattern& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitPattern_(const PatternWildcardNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternVarNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternConstructorNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
return vtable;
}
};
/*! \brief A simple visitor wrapper around PatternFunctor.
*
* Exposes two visitors with default traversal strategies, one
* which doesn't compute a result but can mutate internal state,
* and another which functionally builds a new pattern.
*/
class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n)> {
public:
void VisitPattern_(const PatternWildcardNode* op) override;
void VisitPattern_(const PatternVarNode* op) override;
void VisitPattern_(const PatternConstructorNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitVar(const Var& v);
virtual void VisitConstructor(const Constructor& c);
};
/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
class PatternMutator
: public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> {
public:
Pattern Mutate(const Pattern& pat);
Pattern VisitPattern_(const PatternWildcardNode* op) override;
Pattern VisitPattern_(const PatternVarNode* op) override;
Pattern VisitPattern_(const PatternConstructorNode* op) override;
/*! \brief Used to visit the types inside of patterns.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
/*! \brief Used to visit the vars inside of patterns. */
virtual Var VisitVar(const Var& v);
/*! \brief Used to visit the vars inside of patterns. */
virtual Constructor VisitConstructor(const Constructor& c);
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> var_map_;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PATTERN_FUNCTOR_H_
...@@ -98,6 +98,18 @@ class TensorTypeNode : public BaseTensorTypeNode { ...@@ -98,6 +98,18 @@ class TensorTypeNode : public BaseTensorTypeNode {
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
/*! \brief possible kinds of Type */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
kShapeVar = 1,
kBaseType = 2,
kShape = 3,
kConstraint = 4,
kAdtHandle = 5,
kTypeData = 6
};
/*! /*!
* \brief Type parameter in the function. * \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function. * This can be viewed as template parameter in c++ template function.
...@@ -119,14 +131,6 @@ class TypeVar; ...@@ -119,14 +131,6 @@ class TypeVar;
/*! \brief TypeVar container node */ /*! \brief TypeVar container node */
class TypeVarNode : public TypeNode { class TypeVarNode : public TypeNode {
public: public:
/*! \brief possible kinds of TypeVar */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
kShapeVar = 1,
kBaseType = 2,
kShape = 3
};
/*! /*!
* \brief The variable itself is only meaningful when * \brief The variable itself is only meaningful when
* kind is ShapeVar, otherwise, we only use the name. * kind is ShapeVar, otherwise, we only use the name.
...@@ -150,6 +154,63 @@ class TypeVarNode : public TypeNode { ...@@ -150,6 +154,63 @@ class TypeVarNode : public TypeNode {
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type); RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
/*! /*!
* \brief A global type variable that is used for defining new types or type aliases.
*/
class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode {
public:
/*!
* \brief The variable itself is only meaningful when
* kind is ShapeVar; otherwise, we only use the name.
*/
tvm::Var var;
/*! \brief The kind of type parameter */
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static GlobalTypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type);
/*!
* \brief Type application.
*/
class TypeCall;
/*! \brief TypeCall container node */
class TypeCallNode : public TypeNode {
public:
/*!
* \brief The type-level function (ADT that takes type params).
*/
Type func;
/*! \brief The arguments. */
tvm::Array<Type> args;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("span", &span);
}
TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args);
static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type);
/*!
* \brief IncompleteType. * \brief IncompleteType.
* This is intermediate values that is used during type inference. * This is intermediate values that is used during type inference.
* *
...@@ -162,14 +223,14 @@ class IncompleteType; ...@@ -162,14 +223,14 @@ class IncompleteType;
/*! \brief IncompleteType container node */ /*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode { class IncompleteTypeNode : public TypeNode {
public: public:
TypeVarNode::Kind kind; Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("kind", &kind); v->Visit("kind", &kind);
v->Visit("span", &span); v->Visit("span", &span);
} }
TVM_DLL static IncompleteType make(TypeVarNode::Kind kind); TVM_DLL static IncompleteType make(Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType"; static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode); TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
......
...@@ -7,8 +7,10 @@ from . import ty ...@@ -7,8 +7,10 @@ from . import ty
from . import expr from . import expr
from . import expr_functor from . import expr_functor
from . import module from . import module
from . import adt
from . import ir_pass from . import ir_pass
from .build_module import build, build_config, create_executor, optimize from .build_module import build, build_config, create_executor, optimize
from . import prelude
from . import parser from . import parser
from . import debug from . import debug
...@@ -45,6 +47,8 @@ TypeRelation = ty.TypeRelation ...@@ -45,6 +47,8 @@ TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type scalar_type = ty.scalar_type
RefType = ty.RefType RefType = ty.RefType
GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall
# Expr # Expr
Expr = expr.Expr Expr = expr.Expr
...@@ -61,6 +65,15 @@ RefCreate = expr.RefCreate ...@@ -61,6 +65,15 @@ RefCreate = expr.RefCreate
RefRead = expr.RefRead RefRead = expr.RefRead
RefWrite = expr.RefWrite RefWrite = expr.RefWrite
# ADT
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
Constructor = adt.Constructor
TypeData = adt.TypeData
Clause = adt.Clause
Match = adt.Match
# helper functions # helper functions
var = expr.var var = expr.var
const = expr.const const = expr.const
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Algebraic data types in Relay."""
from .base import RelayNode, register_relay_node, NodeBase
from . import _make
from .ty import Type
from .expr import Expr, Call
class Pattern(RelayNode):
"""Base type for pattern matching constructs."""
pass
@register_relay_node
class PatternWildcard(Pattern):
"""Wildcard pattern in Relay: Matches any ADT and binds nothing."""
def __init__(self):
"""Constructs a wildcard pattern.
Parameters
----------
None
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
self.__init_handle_by_constructor__(_make.PatternWildcard)
@register_relay_node
class PatternVar(Pattern):
"""Variable pattern in Relay: Matches anything and binds it to the variable."""
def __init__(self, var):
"""Construct a variable pattern.
Parameters
----------
var: tvm.relay.Var
Returns
-------
pv: PatternVar
A variable pattern.
"""
self.__init_handle_by_constructor__(_make.PatternVar, var)
@register_relay_node
class PatternConstructor(Pattern):
"""Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively."""
def __init__(self, constructor, patterns=None):
"""Construct a constructor pattern.
Parameters
----------
constructor: Constructor
The constructor.
patterns: Optional[List[Pattern]]
Optional subpatterns: for each field of the constructor,
match to the given subpattern (treated as a variable pattern by default).
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns)
@register_relay_node
class Constructor(Expr):
"""Relay ADT constructor."""
def __init__(self, name_hint, inputs, belong_to):
"""Defines an ADT constructor.
Parameters
----------
name_hint : str
Name of constructor (only a hint).
inputs : List[Type]
Input types.
belong_to : tvm.relay.GlobalTypeVar
Denotes which ADT the constructor belongs to.
Returns
-------
con: Constructor
A constructor.
"""
self.__init_handle_by_constructor__(_make.Constructor, name_hint, inputs, belong_to)
def __call__(self, *args):
"""Call the constructor.
Parameters
----------
args: List[relay.Expr]
The arguments to the constructor.
Returns
-------
call: relay.Call
A call to the constructor.
"""
return Call(self, args)
@register_relay_node
class TypeData(Type):
"""Stores the definition for an Algebraic Data Type (ADT) in Relay.
Note that ADT definitions are treated as type-level functions because
the type parameters need to be given for an instance of the ADT. Thus,
any global type var that is an ADT header needs to be wrapped in a
type call that passes in the type params.
"""
def __init__(self, header, type_vars, constructors):
"""Defines a TypeData object.
Parameters
----------
header: tvm.relay.GlobalTypeVar
The name of the ADT.
ADTs with the same constructors but different names are
treated as different types.
type_vars: List[TypeVar]
Type variables that appear in constructors.
constructors: List[tvm.relay.Constructor]
The constructors for the ADT.
Returns
-------
type_data: TypeData
The adt declaration.
"""
self.__init_handle_by_constructor__(_make.TypeData, header, type_vars, constructors)
@register_relay_node
class Clause(NodeBase):
"""Clause for pattern matching in Relay."""
def __init__(self, lhs, rhs):
"""Construct a clause.
Parameters
----------
lhs: tvm.relay.Pattern
Left-hand side of match clause.
rhs: tvm.relay.Expr
Right-hand side of match clause.
Returns
-------
clause: Clause
The Clause.
"""
self.__init_handle_by_constructor__(_make.Clause, lhs, rhs)
@register_relay_node
class Match(Expr):
"""Pattern matching expression in Relay."""
def __init__(self, data, clauses):
"""Construct a Match.
Parameters
----------
data: tvm.relay.Expr
The value being deconstructed and matched.
clauses: List[tvm.relay.Clause]
The pattern match clauses.
Returns
-------
match: tvm.relay.Expr
The match expression.
"""
self.__init_handle_by_constructor__(_make.Match, data, clauses)
...@@ -292,6 +292,12 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -292,6 +292,12 @@ class GraphRuntimeCodegen(ExprFunctor):
def visit_ref_write(self, _): def visit_ref_write(self, _):
raise RuntimeError("reference not supported") raise RuntimeError("reference not supported")
def visit_constructor(self, _):
raise Exception("ADT constructor case not yet implemented")
def visit_match(self, _):
raise Exception("match case not yet implemented")
def _get_json(self): def _get_json(self):
""" """
Convert the sequence of nodes stored by the compiler into the Convert the sequence of nodes stored by the compiler into the
......
...@@ -53,6 +53,13 @@ class Closure(Value): ...@@ -53,6 +53,13 @@ class Closure(Value):
@register_relay_node @register_relay_node
class ConstructorValue(Value):
def __init__(self, constructor, fields, types):
self.__init_handle_by_constructor__(
_make.ConstructorValue, constructor, fields, types)
@register_relay_node
class TensorValue(Value): class TensorValue(Value):
"""A Tensor value produced by the interpreter.""" """A Tensor value produced by the interpreter."""
......
...@@ -172,6 +172,20 @@ class Var(Expr): ...@@ -172,6 +172,20 @@ class Var(Expr):
name = self.vid.name_hint name = self.vid.name_hint
return name return name
def __call__(self, *args):
"""Call the variable (if it represents a function).
Parameters
----------
args: List[relay.Expr]
The arguments to the call.
Returns
-------
call: Call
A call taking the variable as a function.
"""
return Call(self, args)
@register_relay_node @register_relay_node
class GlobalVar(Expr): class GlobalVar(Expr):
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""The expression functor of Relay.""" """The expression functor of Relay."""
from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant
from .adt import Constructor, Match, Clause
from .op import Op from .op import Op
class ExprFunctor: class ExprFunctor:
...@@ -47,6 +48,10 @@ class ExprFunctor: ...@@ -47,6 +48,10 @@ class ExprFunctor:
res = self.visit_ref_read(expr) res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite): elif isinstance(expr, RefWrite):
res = self.visit_ref_write(expr) res = self.visit_ref_write(expr)
elif isinstance(expr, Constructor):
res = self.visit_constructor(expr)
elif isinstance(expr, Match):
res = self.visit_match(expr)
else: else:
raise Exception("warning unhandled case: {0}".format(type(expr))) raise Exception("warning unhandled case: {0}".format(type(expr)))
...@@ -96,6 +101,13 @@ class ExprFunctor: ...@@ -96,6 +101,13 @@ class ExprFunctor:
def visit_ref_read(self, _): def visit_ref_read(self, _):
raise NotImplementedError() raise NotImplementedError()
def visit_constructor(self, _):
raise NotImplementedError()
def visit_match(self, _):
raise NotImplementedError()
class ExprMutator(ExprFunctor): class ExprMutator(ExprFunctor):
""" """
A functional visitor over Expr. A functional visitor over Expr.
......
...@@ -9,6 +9,7 @@ from . import _ir_pass ...@@ -9,6 +9,7 @@ from . import _ir_pass
from . import _make from . import _make
from .expr import Expr from .expr import Expr
from .ty import Type from .ty import Type
from .module import Module
def post_order_visit(expr, fvisit): def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node, """Recursively visit the ir in post DFS order node,
...@@ -107,7 +108,7 @@ def well_formed(expr): ...@@ -107,7 +108,7 @@ def well_formed(expr):
def check_kind(t, mod=None): def check_kind(t, mod=None):
"""Check that the type is well kinded. """Check that the type is well kinded and return the kind.
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes. For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
Parameters Parameters
...@@ -120,15 +121,15 @@ def check_kind(t, mod=None): ...@@ -120,15 +121,15 @@ def check_kind(t, mod=None):
Returns Returns
------- -------
well_kinded : bool kind : Kind
whether the input type is well kinded. the kind of t
Examples Examples
-------- --------
.. code:: python .. code:: python
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) == Shape
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type
""" """
if mod is not None: if mod is not None:
return _ir_pass.check_kind(t, mod) return _ir_pass.check_kind(t, mod)
...@@ -190,52 +191,61 @@ def all_vars(expr): ...@@ -190,52 +191,61 @@ def all_vars(expr):
return _ir_pass.all_vars(expr) return _ir_pass.all_vars(expr)
def free_type_vars(expr): def free_type_vars(expr, mod=None):
"""Get free type variables from expression/type e """Get free type variables from expression/type e
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional
The global module
Returns Returns
------- -------
free : List[tvm.relay.TypeVar] free : List[tvm.relay.TypeVar]
The list of free type variables in post-DFS order The list of free type variables in post-DFS order
""" """
return _ir_pass.free_type_vars(expr) use_mod = mod if mod is not None else Module()
return _ir_pass.free_type_vars(expr, use_mod)
def bound_type_vars(expr): def bound_type_vars(expr, mod=None):
"""Get bound type variables from expression/type e """Get bound type variables from expression/type e
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional
The global module
Returns Returns
------- -------
free : List[tvm.relay.TypeVar] free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order The list of bound type variables in post-DFS order
""" """
return _ir_pass.bound_type_vars(expr) use_mod = mod if mod is not None else Module()
return _ir_pass.bound_type_vars(expr, use_mod)
def all_type_vars(expr): def all_type_vars(expr, mod=None):
"""Get all type variables from expression/type e """Get all type variables from expression/type e
Parameters Parameters
---------- ----------
expr: Union[tvm.relay.Expr,tvm.relay.Type] expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type The input expression/type
mod: tvm.relay.Module, optional
The global module
Returns Returns
------- -------
free : List[tvm.relay.TypeVar] free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order The list of all type variables in post-DFS order
""" """
return _ir_pass.all_type_vars(expr) use_mod = mod if mod is not None else Module()
return _ir_pass.all_type_vars(expr, use_mod)
def simplify_inference(expr): def simplify_inference(expr):
......
...@@ -6,6 +6,7 @@ from . import _make ...@@ -6,6 +6,7 @@ from . import _make
from . import _module from . import _module
from . import expr as _expr from . import expr as _expr
from . import ty as _ty
@register_relay_node @register_relay_node
class Module(RelayNode): class Module(RelayNode):
...@@ -20,7 +21,7 @@ class Module(RelayNode): ...@@ -20,7 +21,7 @@ class Module(RelayNode):
functions : dict, optional. functions : dict, optional.
Map of global var to Function Map of global var to Function
""" """
def __init__(self, functions=None): def __init__(self, functions=None, type_definitions=None):
if functions is None: if functions is None:
functions = {} functions = {}
elif isinstance(functions, dict): elif isinstance(functions, dict):
...@@ -32,28 +33,46 @@ class Module(RelayNode): ...@@ -32,28 +33,46 @@ class Module(RelayNode):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]") raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
mapped_funcs[k] = v mapped_funcs[k] = v
functions = mapped_funcs functions = mapped_funcs
self.__init_handle_by_constructor__(_make.Module, functions) if type_definitions is None:
type_definitions = {}
elif isinstance(type_definitions, dict):
mapped_type_defs = {}
for k, v in type_definitions.items():
if isinstance(k, _base.string_types):
k = _ty.GlobalTypeVar(k)
if not isinstance(k, _ty.GlobalTypeVar):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_make.Module, functions, type_definitions)
def __setitem__(self, var, func): def __setitem__(self, var, val):
"""Add a function to the module. """Add a mapping to the module.
Parameters Parameters
--------- ---------
var: GlobalVar var: GlobalVar
The global variable which names the function. The global variable.
func: Function val: Union[Function, Type]
The function. The value.
""" """
return self._add(var, func) return self._add(var, val)
def _add(self, var, func, update=False): def _add(self, var, val, update=False):
if isinstance(val, _expr.Function):
if isinstance(var, _base.string_types): if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var) var = _expr.GlobalVar(var)
return _module.Module_Add(self, var, func, update) _make.Module_Add(self, var, val, update)
else:
assert isinstance(val, _ty.Type)
if isinstance(var, _base.string_types):
var = _ty.GlobalTypeVar(var)
_module.Module_AddDef(self, var, val)
def __getitem__(self, var): def __getitem__(self, var):
"""Lookup a global function by name or by variable. """Lookup a global definition by name or by variable.
Parameters Parameters
---------- ----------
...@@ -62,13 +81,15 @@ class Module(RelayNode): ...@@ -62,13 +81,15 @@ class Module(RelayNode):
Returns Returns
------- -------
func: Function val: Union[Function, Type]
The function referenced by :code:`var`. The definition referenced by :code:`var` (either a function or type).
""" """
if isinstance(var, _base.string_types): if isinstance(var, _base.string_types):
return _module.Module_Lookup_str(self, var) return _module.Module_Lookup_str(self, var)
else: elif isinstance(var, _expr.GlobalVar):
return _module.Module_Lookup(self, var) return _module.Module_Lookup(self, var)
else:
return _module.Module_LookupDef(self, var)
def update(self, other): def update(self, other):
"""Insert functions in another Module to current one. """Insert functions in another Module to current one.
...@@ -100,3 +121,22 @@ class Module(RelayNode): ...@@ -100,3 +121,22 @@ class Module(RelayNode):
tvm.TVMError if we cannot find corresponding global var. tvm.TVMError if we cannot find corresponding global var.
""" """
return _module.Module_GetGlobalVar(self, name) return _module.Module_GetGlobalVar(self, name)
def get_global_type_var(self, name):
"""Get a global type variable in the function by name.
Parameters
----------
name: str
The name of the global type variable.
Returns
-------
global_type_var: GlobalTypeVar
The global variable mapped to :code:`name`.
Raises
------
tvm.TVMError if we cannot find corresponding global type var.
"""
return _module.Module_GetGlobalTypeVar(self, name)
...@@ -21,6 +21,19 @@ class Type(RelayNode): ...@@ -21,6 +21,19 @@ class Type(RelayNode):
"""Compares two Relay types by referential equality.""" """Compares two Relay types by referential equality."""
return super().__eq__(other) return super().__eq__(other)
def __call__(self, *args):
"""Create a type call from this type.
Parameters
----------
args: List[relay.Type]
The arguments to the type call.
Returns
-------
call: relay.TypeCall
"""
return TypeCall(self, args)
@register_relay_node @register_relay_node
class TensorType(Type): class TensorType(Type):
...@@ -75,6 +88,9 @@ class Kind(IntEnum): ...@@ -75,6 +88,9 @@ class Kind(IntEnum):
ShapeVar = 1 ShapeVar = 1
BaseType = 2 BaseType = 2
Shape = 3 Shape = 3
Constraint = 4
AdtHandle = 5
TypeData = 6
@register_relay_node @register_relay_node
class TypeVar(Type): class TypeVar(Type):
...@@ -107,6 +123,53 @@ class TypeVar(Type): ...@@ -107,6 +123,53 @@ class TypeVar(Type):
@register_relay_node @register_relay_node
class GlobalTypeVar(Type):
"""A global type variable in Relay.
GlobalTypeVar is used to refer to the global type-level definitions
stored in the environment.
"""
def __init__(self, var, kind=Kind.AdtHandle):
"""Construct a GlobalTypeVar.
Parameters
----------
var: tvm.Var
The tvm.Var which backs the type parameter.
kind: Kind, optional
The kind of the type parameter, Kind.AdtHandle by default.
Returns
-------
type_var: GlobalTypeVar
The global type variable.
"""
self.__init_handle_by_constructor__(_make.GlobalTypeVar, var, kind)
@register_relay_node
class TypeCall(Type):
"""Type-level function application in Relay.
A type call applies argument types to a constructor (type-level function).
"""
def __init__(self, func, args):
"""Construct a TypeCall.
Parameters
----------
func: tvm.relay.Type
The function.
args: List[tvm.expr.Type]
The arguments.
Returns
-------
type_call: TypeCall
The type function application.
"""
self.__init_handle_by_constructor__(_make.TypeCall, func, args)
@register_relay_node
class TypeConstraint(Type): class TypeConstraint(Type):
"""Abstract class representing a type constraint.""" """Abstract class representing a type constraint."""
pass pass
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/attrs/debug.h> #include <tvm/relay/attrs/debug.h>
...@@ -92,6 +93,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -92,6 +93,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefValueNode(" << node->value << ")"; p->stream << "RefValueNode(" << node->value << ")";
}); });
ConstructorValue ConstructorValueNode::make(Constructor constructor,
tvm::Array<Value> fields) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
n->constructor = constructor;
n->fields = fields;
return ConstructorValue(n);
}
TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstructorValueNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
tvm::IRPrinter* p) {
p->stream << "ConstructorValueNode(" << node->constructor
<< node->fields << ")";
});
/*! /*!
* \brief A stack frame in the Relay interpreter. * \brief A stack frame in the Relay interpreter.
* *
...@@ -185,7 +206,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) { ...@@ -185,7 +206,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
// //
// Conversion to ANF is recommended before running the interpretation. // Conversion to ANF is recommended before running the interpretation.
class Interpreter : class Interpreter :
public ExprFunctor<Value(const Expr& n)> { public ExprFunctor<Value(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const Value& v)> {
public: public:
Interpreter(Module mod, Interpreter(Module mod,
DLContext context, DLContext context,
...@@ -209,7 +231,7 @@ class Interpreter : ...@@ -209,7 +231,7 @@ class Interpreter :
} }
Value Eval(const Expr& expr) { Value Eval(const Expr& expr) {
return (*this)(expr); return VisitExpr(expr);
} }
Value VisitExpr(const Expr& expr) final { Value VisitExpr(const Expr& expr) final {
...@@ -401,6 +423,9 @@ class Interpreter : ...@@ -401,6 +423,9 @@ class Interpreter :
<< "; operators should be removed by future passes; try " << "; operators should be removed by future passes; try "
"fusing and lowering"; "fusing and lowering";
} }
if (auto con = call->op.as<ConstructorNode>()) {
return ConstructorValueNode::make(GetRef<Constructor>(con), args);
}
// Now we just evaluate and expect to find a closure. // Now we just evaluate and expect to find a closure.
Value fn_val = Eval(call->op); Value fn_val = Eval(call->op);
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) { if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
...@@ -474,6 +499,44 @@ class Interpreter : ...@@ -474,6 +499,44 @@ class Interpreter :
} }
} }
Value VisitExpr_(const MatchNode* op) final {
Value v = Eval(op->data);
for (const Clause& c : op->clauses) {
if (VisitPattern(c->lhs, v)) {
return VisitExpr(c->rhs);
}
}
LOG(FATAL) << "did not find any match";
return Value();
}
bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
CHECK(cvn) << "need to be a constructor for match";
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(cvn->constructor->tag, -1);
if (op->constructor->tag == cvn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK(op->patterns.size() == cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
return false;
}
}
return true;
}
return false;
}
bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
return true;
}
bool VisitPattern_(const PatternVarNode* op, const Value& v) final {
extend(op->var, v);
return true;
}
InterpreterState get_state(Expr e = Expr()) const { InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack; InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) { for (auto fr : this->stack_.frames) {
...@@ -485,14 +548,14 @@ class Interpreter : ...@@ -485,14 +548,14 @@ class Interpreter :
} }
private: private:
// module // Module
Module mod_; Module mod_;
// For simplicity we only run the interpreter on a single context. // For simplicity we only run the interpreter on a single context.
// Context to run the interpreter on. // Context to run the interpreter on.
DLContext context_; DLContext context_;
// Target parameter being used by the interpreter. // Target parameter being used by the interpreter.
Target target_; Target target_;
// value stack. // Value stack.
Stack stack_; Stack stack_;
// Backend compile engine. // Backend compile engine.
CompileEngine engine_; CompileEngine engine_;
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/ir/adt.cc
* \brief AST nodes for Relay algebraic data types (ADTs).
*/
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
namespace tvm {
namespace relay {
PatternWildcard PatternWildcardNode::make() {
NodePtr<PatternWildcardNode> n = make_node<PatternWildcardNode>();
return PatternWildcard(n);
}
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_API("relay._make.PatternWildcard")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternWildcardNode::make();
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternWildcardNode()";
});
PatternVar PatternVarNode::make(tvm::relay::Var var) {
NodePtr<PatternVarNode> n = make_node<PatternVarNode>();
n->var = std::move(var);
return PatternVar(n);
}
TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_API("relay._make.PatternVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternVarNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const PatternVarNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternVarNode(" << node->var << ")";
});
PatternConstructor PatternConstructorNode::make(Constructor constructor,
tvm::Array<Pattern> patterns) {
NodePtr<PatternConstructorNode> n = make_node<PatternConstructorNode>();
n->constructor = std::move(constructor);
n->patterns = std::move(patterns);
return PatternConstructor(n);
}
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_API("relay._make.PatternConstructor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternConstructorNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternConstructorNode(" << node->constructor
<< ", " << node->patterns << ")";
});
Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
NodePtr<ConstructorNode> n = make_node<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
n->belong_to = std::move(belong_to);
return Constructor(n);
}
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_API("relay._make.Constructor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstructorNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ConstructorNode* node,
tvm::IRPrinter* p) {
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
});
TypeData TypeDataNode::make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
NodePtr<TypeDataNode> n = make_node<TypeDataNode>();
n->header = std::move(header);
n->type_vars = std::move(type_vars);
n->constructors = std::move(constructors);
return TypeData(n);
}
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_API("relay._make.TypeData")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeDataNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const TypeDataNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
});
Clause ClauseNode::make(Pattern lhs, Expr rhs) {
NodePtr<ClauseNode> n = make_node<ClauseNode>();
n->lhs = std::move(lhs);
n->rhs = std::move(rhs);
return Clause(n);
}
TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_API("relay._make.Clause")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ClauseNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ClauseNode* node,
tvm::IRPrinter* p) {
p->stream << "ClauseNode(" << node->lhs << ", "
<< node->rhs << ")";
});
Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) {
NodePtr<MatchNode> n = make_node<MatchNode>();
n->data = std::move(data);
n->clauses = std::move(clauses);
return Match(n);
}
TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_API("relay._make.Match")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = MatchNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node,
tvm::IRPrinter* p) {
p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ")";
});
} // namespace relay
} // namespace tvm
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include "type_functor.h" #include "type_functor.h"
...@@ -17,7 +18,8 @@ namespace relay { ...@@ -17,7 +18,8 @@ namespace relay {
class AlphaEqualHandler: class AlphaEqualHandler:
public AttrsEqualHandler, public AttrsEqualHandler,
public TypeFunctor<bool(const Type&, const Type&)>, public TypeFunctor<bool(const Type&, const Type&)>,
public ExprFunctor<bool(const Expr&, const Expr&)> { public ExprFunctor<bool(const Expr&, const Expr&)>,
public PatternFunctor<bool(const Pattern&, const Pattern&)> {
public: public:
explicit AlphaEqualHandler(bool map_free_var) explicit AlphaEqualHandler(bool map_free_var)
: map_free_var_(map_free_var) {} : map_free_var_(map_free_var) {}
...@@ -160,7 +162,7 @@ class AlphaEqualHandler: ...@@ -160,7 +162,7 @@ class AlphaEqualHandler:
} }
equal_map_[lhs->type_params[i]] = rhs->type_params[i]; equal_map_[lhs->type_params[i]] = rhs->type_params[i];
// set up type parameter equal // set up type parameter equal
if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) { if (lhs->type_params[i]->kind == Kind::kShapeVar) {
// map variable // map variable
equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var; equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
} }
...@@ -215,6 +217,26 @@ class AlphaEqualHandler: ...@@ -215,6 +217,26 @@ class AlphaEqualHandler:
return false; return false;
} }
bool VisitType_(const GlobalTypeVarNode* op, const Type& t2) final {
return GetRef<Type>(op) == t2;
}
bool VisitType_(const TypeCallNode* op, const Type& t2) final {
const TypeCallNode* pt = t2.as<TypeCallNode>();
if (pt == nullptr
|| op->args.size() != pt->args.size()
|| !TypeEqual(op->func, pt->func)) {
return false;
}
for (size_t i = 0; i < op->args.size(); ++i) {
if (!TypeEqual(op->args[i], pt->args[i])) {
return false;
}
}
return true;
}
// Expr equal checking. // Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs, bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) { const runtime::NDArray& rhs) {
...@@ -261,11 +283,9 @@ class AlphaEqualHandler: ...@@ -261,11 +283,9 @@ class AlphaEqualHandler:
bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final { bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) { if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
// use name equality for global var for now. // use name equality for global var for now.
if (lhs->name_hint != rhs->name_hint) return false; return lhs->name_hint == rhs->name_hint;
return true;
} else {
return false;
} }
return false;
} }
bool VisitExpr_(const TupleNode* lhs, const Expr& other) final { bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
...@@ -392,6 +412,63 @@ class AlphaEqualHandler: ...@@ -392,6 +412,63 @@ class AlphaEqualHandler:
return false; return false;
} }
} }
bool VisitExpr_(const ConstructorNode* op, const Expr& e2) final {
return GetRef<Expr>(op) == e2;
}
bool ClauseEqual(const Clause& l, const Clause& r) {
return PatternEqual(l->lhs, r->lhs) && ExprEqual(l->rhs, r->rhs);
}
bool PatternEqual(const Pattern& l, const Pattern& r) {
return VisitPattern(l, r);
}
bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) final {
return r.as<PatternWildcardNode>();
}
bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) final {
if (const auto* r = e2.as<PatternVarNode>()) {
return MergeVarDecl(op->var, r->var);
}
return false;
}
bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) final {
const auto* r = e2.as<PatternConstructorNode>();
if (r == nullptr
|| !ExprEqual(op->constructor, r->constructor)
|| op->patterns.size() != r->patterns.size()) {
return false;
}
for (size_t i = 0; i < op->patterns.size(); i++) {
if (!PatternEqual(op->patterns[i], r->patterns[i])) {
return false;
}
}
return true;
}
bool VisitExpr_(const MatchNode* op, const Expr& e2) final {
const MatchNode* r = e2.as<MatchNode>();
if (r == nullptr
|| !ExprEqual(op->data, r->data)
|| op->clauses.size() != r->clauses.size()) {
return false;
}
for (size_t i = 0; i < op->clauses.size(); ++i) {
if (!ClauseEqual(op->clauses[i], r->clauses[i])) {
return false;
}
}
return true;
}
private: private:
// whether to map open terms. // whether to map open terms.
bool map_free_var_{false}; bool map_free_var_{false};
......
...@@ -130,9 +130,14 @@ Function FunctionNode::make(tvm::Array<Var> params, ...@@ -130,9 +130,14 @@ Function FunctionNode::make(tvm::Array<Var> params,
FuncType FunctionNode::func_type_annotation() const { FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types; Array<Type> param_types;
for (auto param : this->params) { for (auto param : this->params) {
param_types.push_back(param->type_annotation); Type param_type = (param->type_annotation.defined()) ? param->type_annotation
: IncompleteTypeNode::make(Kind::kType);
param_types.push_back(param_type);
} }
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteTypeNode::make(Kind::kType);
return FuncTypeNode::make(param_types, ret_type, this->type_params, {});
} }
bool FunctionNode::IsPrimitive() const { bool FunctionNode::IsPrimitive() const {
......
...@@ -185,6 +185,24 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { ...@@ -185,6 +185,24 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
} }
} }
Expr ExprMutator::VisitExpr_(const ConstructorNode* c) {
return GetRef<Expr>(c);
}
Expr ExprMutator::VisitExpr_(const MatchNode* m) {
std::vector<Clause> clauses;
for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p));
}
return MatchNode::make(VisitExpr(m->data), clauses);
}
Clause ExprMutator::VisitClause(const Clause& c) {
return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs));
}
Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
Type ExprMutator::VisitType(const Type& t) { return t; } Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::VisitExpr(const Expr& expr) { void ExprVisitor::VisitExpr(const Expr& expr) {
...@@ -267,6 +285,27 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) { ...@@ -267,6 +285,27 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
this->VisitExpr(op->value); this->VisitExpr(op->value);
} }
void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
for (const Type& t : op->inputs) {
this->VisitType(t);
}
this->VisitType(op->belong_to);
}
void ExprVisitor::VisitExpr_(const MatchNode* op) {
this->VisitExpr(op->data);
for (const Clause& c : op->clauses) {
this->VisitClause(c);
}
}
void ExprVisitor::VisitClause(const Clause& op) {
this->VisitPattern(op->lhs);
this->VisitExpr(op->rhs);
}
void ExprVisitor::VisitPattern(const Pattern& p) { return; }
void ExprVisitor::VisitType(const Type& t) { return; } void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply // visitor to implement apply
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/attrs.h> #include <tvm/attrs.h>
...@@ -18,7 +19,8 @@ namespace relay { ...@@ -18,7 +19,8 @@ namespace relay {
class RelayHashHandler: class RelayHashHandler:
public AttrsHashHandler, public AttrsHashHandler,
public TypeFunctor<size_t(const Type&)>, public TypeFunctor<size_t(const Type&)>,
public ExprFunctor<size_t(const Expr&)> { public ExprFunctor<size_t(const Expr&)>,
public PatternFunctor<size_t(const Pattern&)> {
public: public:
explicit RelayHashHandler() {} explicit RelayHashHandler() {}
...@@ -201,7 +203,7 @@ class RelayHashHandler: ...@@ -201,7 +203,7 @@ class RelayHashHandler:
hash_map_[var] = hash; hash_map_[var] = hash;
const auto* ty_param = var.as<TypeVarNode>(); const auto* ty_param = var.as<TypeVarNode>();
if (ty_param && ty_param->kind == TypeVarNode::Kind::kShapeVar) { if (ty_param && ty_param->kind == Kind::kShapeVar) {
hash_map_[ty_param->var] = hash; hash_map_[ty_param->var] = hash;
} }
return hash; return hash;
...@@ -249,6 +251,10 @@ class RelayHashHandler: ...@@ -249,6 +251,10 @@ class RelayHashHandler:
hash = Combine(hash, ExprHash(arg)); hash = Combine(hash, ExprHash(arg));
} }
for (auto t : call->type_args) {
hash = Combine(hash, TypeHash(t));
}
hash = Combine(hash, AttrHash(call->attrs)); hash = Combine(hash, AttrHash(call->attrs));
return hash; return hash;
...@@ -304,6 +310,72 @@ class RelayHashHandler: ...@@ -304,6 +310,72 @@ class RelayHashHandler:
hash = Combine(hash, ExprHash(rn->value)); hash = Combine(hash, ExprHash(rn->value));
return hash; return hash;
} }
size_t VisitExpr_(const MatchNode* mn) final {
size_t hash = std::hash<std::string>()(MatchNode::_type_key);
hash = Combine(hash, ExprHash(mn->data));
for (const auto& c : mn->clauses) {
hash = Combine(hash, PatternHash(c->lhs));
hash = Combine(hash, ExprHash(c->rhs));
}
return hash;
}
size_t VisitExpr_(const ConstructorNode* cn) final {
size_t hash = std::hash<std::string>()(ConstructorNode::_type_key);
hash = Combine(hash, std::hash<std::string>()(cn->name_hint));
return hash;
}
size_t VisitType_(const TypeCallNode* tcn) final {
size_t hash = std::hash<std::string>()(TypeCallNode::_type_key);
hash = Combine(hash, TypeHash(tcn->func));
for (const auto& t : tcn->args) {
hash = Combine(hash, TypeHash(t));
}
return hash;
}
size_t VisitType_(const TypeDataNode* tdn) final {
size_t hash = std::hash<std::string>()(TypeDataNode::_type_key);
hash = Combine(hash, TypeHash(tdn->header));
for (const auto& tv : tdn->type_vars) {
hash = Combine(hash, TypeHash(tv));
}
for (const auto& cn : tdn->constructors) {
hash = Combine(hash, ExprHash(cn));
}
return hash;
}
size_t VisitType_(const GlobalTypeVarNode* tvn) final {
return BindVar(GetRef<GlobalTypeVar>(tvn));
}
size_t PatternHash(const Pattern& p) {
return VisitPattern(p);
}
size_t VisitPattern_(const PatternConstructorNode* pcn) final {
size_t hash = std::hash<std::string>()(PatternConstructorNode::_type_key);
hash = Combine(hash, ExprHash(pcn->constructor));
for (const auto& p : pcn->patterns) {
hash = Combine(hash, PatternHash(p));
}
return hash;
}
size_t VisitPattern_(const PatternVarNode* pvn) final {
size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
hash = Combine(hash, BindVar(pvn->var));
return hash;
}
size_t VisitPattern_(const PatternWildcardNode* pwn) final {
size_t hash = std::hash<std::string>()(PatternWildcardNode::_type_key);
return hash;
}
private: private:
// renaming of NodeRef to indicate two nodes equals to each other // renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_; std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_;
......
...@@ -13,18 +13,28 @@ namespace relay { ...@@ -13,18 +13,28 @@ namespace relay {
using tvm::IRPrinter; using tvm::IRPrinter;
using namespace runtime; using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) { Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs) {
auto n = make_node<ModuleNode>(); auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs); n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs);
for (const auto& kv : n->functions) { for (const auto& kv : n->functions) {
// set gloval var map // set global var map
CHECK(!n->global_var_map_.count(kv.first->name_hint)) CHECK(!n->global_var_map_.count(kv.first->name_hint))
<< "Duplicate global function name " << kv.first->name_hint; << "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first); n->global_var_map_.Set(kv.first->name_hint, kv.first);
} }
n->entry_func = GlobalVarNode::make("main"); n->entry_func = GlobalVarNode::make("main");
for (const auto& kv : n->type_definitions) {
// set global typevar map
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
<< "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
}
return Module(n); return Module(n);
} }
...@@ -51,6 +61,13 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, ...@@ -51,6 +61,13 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
global_var_map_.Set(var->name_hint, var); global_var_map_.Set(var->name_hint, var);
} }
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) {
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
return (*it).second;
}
void ModuleNode::Add(const GlobalVar& var, void ModuleNode::Add(const GlobalVar& var,
const Function& func, const Function& func,
bool update) { bool update) {
...@@ -69,6 +86,22 @@ void ModuleNode::Add(const GlobalVar& var, ...@@ -69,6 +86,22 @@ void ModuleNode::Add(const GlobalVar& var,
AddUnchecked(var, checked_func); AddUnchecked(var, checked_func);
} }
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
this->type_definitions.Set(var, type);
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = i;
}
// need to kind check at the end because the check can look up
// a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}
void ModuleNode::Update(const GlobalVar& var, const Function& func) { void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true); this->Add(var, func, true);
} }
...@@ -92,6 +125,18 @@ Function ModuleNode::Lookup(const std::string& name) { ...@@ -92,6 +125,18 @@ Function ModuleNode::Lookup(const std::string& name) {
return this->Lookup(id); return this->Lookup(id);
} }
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->var->name_hint;
return (*it).second;
}
TypeData ModuleNode::LookupDef(const std::string& name) {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupDef(id);
}
void ModuleNode::Update(const Module& mod) { void ModuleNode::Update(const Module& mod) {
for (auto pair : mod->functions) { for (auto pair : mod->functions) {
this->Update(pair.first, pair.second); this->Update(pair.first, pair.second);
...@@ -101,7 +146,7 @@ void ModuleNode::Update(const Module& mod) { ...@@ -101,7 +146,7 @@ void ModuleNode::Update(const Module& mod) {
Module ModuleNode::FromExpr( Module ModuleNode::FromExpr(
const Expr& expr, const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) { const tvm::Map<GlobalVar, Function>& global_funcs) {
auto mod = ModuleNode::make(global_funcs); auto mod = ModuleNode::make(global_funcs, {});
auto func_node = expr.as<FunctionNode>(); auto func_node = expr.as<FunctionNode>();
Function func; Function func;
if (func_node) { if (func_node) {
...@@ -117,21 +162,33 @@ TVM_REGISTER_NODE_TYPE(ModuleNode); ...@@ -117,21 +162,33 @@ TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module") TVM_REGISTER_API("relay._make.Module")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ModuleNode::make(args[0]); *ret = ModuleNode::make(args[0], args[1]);
}); });
TVM_REGISTER_API("relay._module.Module_Add") TVM_REGISTER_API("relay._make.Module_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0]; Module mod = args[0];
mod->Add(args[1], args[2], args[3]); mod->Add(args[1], args[2], args[3]);
}); });
TVM_REGISTER_API("relay._module.Module_AddDef")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
mod->AddDef(args[1], args[2]);
});
TVM_REGISTER_API("relay._module.Module_GetGlobalVar") TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0]; Module mod = args[0];
*ret = mod->GetGlobalVar(args[1]); *ret = mod->GetGlobalVar(args[1]);
}); });
TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
*ret = mod->GetGlobalTypeVar(args[1]);
});
TVM_REGISTER_API("relay._module.Module_Lookup") TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0]; Module mod = args[0];
...@@ -143,8 +200,21 @@ TVM_REGISTER_API("relay._module.Module_Lookup_str") ...@@ -143,8 +200,21 @@ TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0]; Module mod = args[0];
std::string var_name = args[1]; std::string var_name = args[1];
auto var = mod->GetGlobalVar(var_name); *ret = mod->Lookup(var_name);
*ret = mod->Lookup(var); });
TVM_REGISTER_API("relay._module.Module_LookupDef")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
GlobalTypeVar var = args[1];
*ret = mod->LookupDef(var);
});
TVM_REGISTER_API("relay._module.Module_LookupDef_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
std::string var_name = args[1];
*ret = mod->LookupDef(var_name);
}); });
TVM_REGISTER_API("relay._module.Module_Update") TVM_REGISTER_API("relay._module.Module_Update")
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pattern_functor.cc
* \brief Implementations of visitors and mutators for ADT patterns.
*/
#include <tvm/relay/pattern_functor.h>
namespace tvm {
namespace relay {
Pattern PatternMutator::Mutate(const Pattern& pat) {
return (*this)(pat);
}
Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) {
return GetRef<Pattern>(op);
}
Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) {
return PatternVarNode::make(VisitVar(op->var));
}
Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
std::vector<Pattern> pat;
for (const auto& p : op->patterns) {
pat.push_back(VisitPattern(p));
}
return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
}
Type PatternMutator::VisitType(const Type& t) {
return t;
}
Var PatternMutator::VisitVar(const Var& v) {
if (var_map_.count(v) == 0) {
var_map_.insert(std::pair<Var, Var>(v,
VarNode::make(v->name_hint(),
VisitType(v->type_annotation))));
}
return var_map_.at(v);
}
Constructor PatternMutator::VisitConstructor(const Constructor& v) {
return v;
}
void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { }
void PatternVisitor::VisitPattern_(const PatternVarNode* op) {
VisitVar(op->var);
}
void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
VisitConstructor(op->constructor);
for (const auto& p : op->patterns) {
VisitPattern(p);
}
}
void PatternVisitor::VisitType(const Type& t) { }
void PatternVisitor::VisitVar(const Var& v) {
VisitType(v->type_annotation);
}
void PatternVisitor::VisitConstructor(const Constructor& c) {
for (const auto& inp : c->inputs) {
VisitType(inp);
}
}
} // namespace relay
} // namespace tvm
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
*/ */
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <sstream> #include <sstream>
#include "type_functor.h" #include "type_functor.h"
#include "../../lang/attr_functor.h" #include "../../lang/attr_functor.h"
...@@ -23,6 +24,12 @@ struct TextValue { ...@@ -23,6 +24,12 @@ struct TextValue {
TextValue() {} TextValue() {}
// constructor // constructor
explicit TextValue(std::string name) : name(name) {} explicit TextValue(std::string name) : name(name) {}
TextValue operator+(const TextValue& rhs) const {
return TextValue(name + rhs.name);
}
TextValue operator+(const std::string& str) const {
return TextValue(name + str);
}
}; };
// operator overloading // operator overloading
...@@ -128,6 +135,7 @@ class TextMetaDataContext { ...@@ -128,6 +135,7 @@ class TextMetaDataContext {
class TextPrinter : class TextPrinter :
public ExprFunctor<TextValue(const Expr&)>, public ExprFunctor<TextValue(const Expr&)>,
public PatternFunctor<TextValue(const Pattern&)>,
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*) public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*) public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public: public:
...@@ -213,6 +221,9 @@ class TextPrinter : ...@@ -213,6 +221,9 @@ class TextPrinter :
memo_[expr] = val; memo_[expr] = val;
return val; return val;
} }
TextValue GetValue(const Pattern& p) {
return this->VisitPattern(p);
}
//------------------------------------ //------------------------------------
// Overload of Expr printing functions // Overload of Expr printing functions
//------------------------------------ //------------------------------------
...@@ -391,6 +402,36 @@ class TextPrinter : ...@@ -391,6 +402,36 @@ class TextPrinter :
return id; return id;
} }
TextValue VisitExpr_(const MatchNode* op) final {
TextValue data = GetValue(op->data);
this->PrintIndent();
TextValue id = this->AllocTempVar();
stream_ << id << " = " << "Match " << data << " with";
this->PrintEndInst("\n");
for (const auto& c : op->clauses) {
this->PrintIndent();
stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs);
this->PrintEndInst("\n");
}
return id;
}
TextValue VisitPattern_(const PatternConstructorNode* p) final {
TextValue ret(p->constructor->name_hint + "(");
for (const Pattern& pat : p->patterns) {
ret = ret + " " + GetValue(pat);
}
return ret + ")";
}
TextValue VisitPattern_(const PatternVarNode* pv) final {
return GetValue(pv->var);
}
TextValue VisitExpr_(const ConstructorNode* n) final {
return TextValue(n->name_hint);
}
/*! /*!
* \brief Print the type to os * \brief Print the type to os
* \param type The type to be printed. * \param type The type to be printed.
...@@ -437,6 +478,18 @@ class TextPrinter : ...@@ -437,6 +478,18 @@ class TextPrinter :
VisitTypeDefault_(node, os); VisitTypeDefault_(node, os);
} }
void VisitType_(const TypeCallNode* node, std::ostream& os) final {
os << node->func << "(" << node->args << ")";
}
void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final {
VisitTypeDefault_(node, os);
}
void VisitType_(const TypeDataNode* node, std::ostream& os) final {
VisitTypeDefault_(node, os);
}
void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*)
// by default always print as meta-data // by default always print as meta-data
os << meta_.GetMetaNode(GetRef<NodeRef>(node)); os << meta_.GetMetaNode(GetRef<NodeRef>(node));
......
...@@ -48,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -48,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")"; p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
}); });
TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) { TypeVar TypeVarNode::make(std::string name, Kind kind) {
NodePtr<TypeVarNode> n = make_node<TypeVarNode>(); NodePtr<TypeVarNode> n = make_node<TypeVarNode>();
n->var = tvm::Var(name); n->var = tvm::Var(name);
n->kind = std::move(kind); n->kind = std::move(kind);
...@@ -61,7 +61,7 @@ TVM_REGISTER_API("relay._make.TypeVar") ...@@ -61,7 +61,7 @@ TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1]; int kind = args[1];
*ret = *ret =
TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind)); TypeVarNode::make(args[0], static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -71,7 +71,50 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -71,7 +71,50 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->kind << ")"; << node->kind << ")";
}); });
IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) { GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
NodePtr<GlobalTypeVarNode> n = make_node<GlobalTypeVarNode>();
n->var = tvm::Var(name);
n->kind = std::move(kind);
return GlobalTypeVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_API("relay._make.GlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret = GlobalTypeVarNode::make(args[0], static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
tvm::IRPrinter *p) {
p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")";
});
TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
NodePtr<TypeCallNode> n = make_node<TypeCallNode>();
n->func = std::move(func);
n->args = std::move(args);
return TypeCall(n);
}
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_API("relay._make.TypeCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeCallNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const TypeCallNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
});
IncompleteType IncompleteTypeNode::make(Kind kind) {
auto n = make_node<IncompleteTypeNode>(); auto n = make_node<IncompleteTypeNode>();
n->kind = std::move(kind); n->kind = std::move(kind);
return IncompleteType(n); return IncompleteType(n);
...@@ -82,7 +125,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); ...@@ -82,7 +125,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType") TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0]; int kind = args[0];
*ret = IncompleteTypeNode::make(static_cast<TypeVarNode::Kind>(kind)); *ret = IncompleteTypeNode::make(static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
...@@ -48,6 +48,29 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) { ...@@ -48,6 +48,29 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) {
} }
} }
void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {
}
void TypeVisitor::VisitType_(const TypeCallNode* op) {
this->VisitType(op->func);
for (const Type& t : op->args) {
this->VisitType(t);
}
}
void TypeVisitor::VisitType_(const TypeDataNode* op) {
this->VisitType(op->header);
for (const auto& v : op->type_vars) {
this->VisitType(v);
}
for (const auto& c : op->constructors) {
this->VisitType(c->belong_to);
for (const auto& t : c->inputs) {
this->VisitType(t);
}
}
}
// Type Mutator. // Type Mutator.
Array<Type> TypeMutator::MutateArray(Array<Type> arr) { Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
...@@ -139,6 +162,24 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { ...@@ -139,6 +162,24 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
} }
} }
Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) {
return GetRef<Type>(op);
}
Type TypeMutator::VisitType_(const TypeCallNode* op) {
Type new_func = VisitType(op->func);
Array<Type> new_args = MutateArray(op->args);
if (new_args.same_as(op->args) && new_func.same_as(op->func)) {
return GetRef<TypeCall>(op);
} else {
return TypeCallNode::make(new_func, new_args);
}
}
Type TypeMutator::VisitType_(const TypeDataNode* op) {
return GetRef<Type>(op);
}
// Implements bind. // Implements bind.
class TypeBinder : public TypeMutator { class TypeBinder : public TypeMutator {
public: public:
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <tvm/node/ir_functor.h> #include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -69,6 +70,10 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -69,6 +70,10 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Node* op, Args...) { virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key(); LOG(FATAL) << "Do not have a default for " << op->type_key();
throw; // unreachable, written to stop compiler warning throw; // unreachable, written to stop compiler warning
...@@ -87,6 +92,9 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -87,6 +92,9 @@ class TypeFunctor<R(const Type& n, Args...)> {
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
return vtable; return vtable;
} }
}; };
...@@ -103,6 +111,9 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> { ...@@ -103,6 +111,9 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
void VisitType_(const TupleTypeNode* op) override; void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override; void VisitType_(const TypeRelationNode* op) override;
void VisitType_(const RefTypeNode* op) override; void VisitType_(const RefTypeNode* op) override;
void VisitType_(const GlobalTypeVarNode* op) override;
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
}; };
// Mutator that transform a type to another one. // Mutator that transform a type to another one.
...@@ -115,6 +126,9 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> { ...@@ -115,6 +126,9 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TupleTypeNode* op) override; Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override; Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const RefTypeNode* op) override; Type VisitType_(const RefTypeNode* op) override;
Type VisitType_(const GlobalTypeVarNode* op) override;
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
private: private:
Array<Type> MutateArray(Array<Type> arr); Array<Type> MutateArray(Array<Type> arr);
......
...@@ -296,6 +296,15 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -296,6 +296,15 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
this->AddNode(op); this->AddNode(op);
} }
void VisitExpr_(const MatchNode* op) final {
this->Update(op->data, nullptr, kOpaque);
for (const Clause& c : op->clauses) {
this->Update(c->rhs, nullptr, kOpaque);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
}; };
IndexedForwardGraph IndexedForwardGraph::Create( IndexedForwardGraph IndexedForwardGraph::Create(
......
...@@ -14,106 +14,160 @@ ...@@ -14,106 +14,160 @@
* contains a data type such as `int`, `float`, `uint`. * contains a data type such as `int`, `float`, `uint`.
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/error.h>
#include "../ir/type_functor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
using namespace tvm::runtime; using namespace tvm::runtime;
using Kind = TypeVarNode::Kind;
struct KindChecker : TypeVisitor { struct KindChecker : TypeFunctor<Kind(const Type&)> {
bool valid; const Module& mod;
ErrorReporter err_reporter;
KindChecker() : valid(true) {} explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {}
// checks if t is an incomplete node of kind k or a type param of kind k void ReportFatalError(const Error& err) {
bool MatchKind(const Type& t, Kind k) { this->err_reporter.Report(err);
if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) { this->err_reporter.RenderErrors(mod);
return tv->kind == k;
} }
if (const TypeVarNode* tp = t.as<TypeVarNode>()) { void CheckKindMatches(const Type& t, const Type& outer,
return tp->kind == k; Kind expected, const std::string& description) {
Kind k = this->VisitType(t);
if (k != expected) {
ReportFatalError(RELAY_ERROR("Incorrect kind for a " << description
<< ". Type " << t << " inside " << outer
<< " is of kind " << k
<< " but was expected to be "
<< expected));
}
}
Kind VisitType_(const IncompleteTypeNode* op) override {
return op->kind;
} }
return false; Kind VisitType_(const TypeVarNode* op) override {
return op->kind;
} }
bool IsTypeKind(const Type& t) { Kind VisitType_(const GlobalTypeVarNode* op) override {
if (MatchKind(t, Kind::kType)) { return op->kind;
return true;
} }
return t.as_derived<BaseTensorTypeNode>() || t.as<TupleTypeNode>() || t.as<FuncTypeNode>(); Kind VisitType_(const TensorTypeNode* op) override {
return Kind::kType;
} }
void VisitType_(const TupleTypeNode* op) override { Kind VisitType_(const TupleTypeNode* op) override {
// tuples should only contain normal types // tuples should only contain normal types
for (const Type& t : op->fields) { for (const Type& t : op->fields) {
this->VisitType(t); CheckKindMatches(t, GetRef<TupleType>(op), Kind::kType,
valid = valid && IsTypeKind(t); "tuple member");
if (!valid) {
return;
}
} }
return Kind::kType;
} }
void VisitType_(const FuncTypeNode* op) override { Kind VisitType_(const FuncTypeNode* op) override {
// Func types should only take normal types for arguments // Func types should only take normal types for arguments
// and only return a normal type. They should also have // and only return a normal type. They should also have
// well-formed constraints // well-formed constraints
FuncType ft = GetRef<FuncType>(op);
for (const Type& t : op->arg_types) { for (const Type& t : op->arg_types) {
this->VisitType(t); CheckKindMatches(t, ft, Kind::kType, "function type parameter");
valid = valid && IsTypeKind(t);
if (!valid) {
return;
}
} }
CheckKindMatches(ft->ret_type, ft, Kind::kType, "function return type");
for (const TypeConstraint& tc : op->type_constraints) { for (const TypeConstraint& tc : op->type_constraints) {
this->VisitType(tc); CheckKindMatches(tc, ft, Kind::kConstraint, "function type constraint");
if (!valid) {
return;
}
} }
this->VisitType(op->ret_type); return Kind::kType;
valid = valid && IsTypeKind(op->ret_type);
} }
void VisitType_(const RefTypeNode* op) override { Kind VisitType_(const RefTypeNode* op) override {
// tuples should only contain normal types // ref types should only contain normal types
this->VisitType(op->value); RefType rt = GetRef<RefType>(op);
valid = valid && IsTypeKind(op->value); CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
return Kind::kType;
} }
void VisitType_(const TypeRelationNode* op) override { Kind VisitType_(const TypeRelationNode* op) override {
// arguments to type relation should be normal types // arguments to type relation should be normal types
for (const Type& t : op->args) { for (const Type& t : op->args) {
this->VisitType(t); CheckKindMatches(t, GetRef<TypeRelation>(op), Kind::kType,
valid = valid && IsTypeKind(t); "argument to type relation");
if (!valid) { }
return; return Kind::kConstraint;
}
Kind VisitType_(const TypeCallNode* op) override {
// type call func should be a global type var, args should be type
TypeCall tc = GetRef<TypeCall>(op);
const auto* gtv = op->func.as<GlobalTypeVarNode>();
if (gtv == nullptr) {
ReportFatalError(RELAY_ERROR("The callee in " << tc
<< " is not a global type var, but is " << op->func));
}
CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function");
for (const Type& t : op->args) {
CheckKindMatches(t, tc, Kind::kType, "type call argument");
}
// finally we need to check the module to check the number of type params
auto var = GetRef<GlobalTypeVar>(gtv);
auto data = mod->LookupDef(var);
if (data->type_vars.size() != op->args.size()) {
ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc
<< "; got " << op->args.size()));
}
return Kind::kType;
}
Kind VisitType_(const TypeDataNode* op) override {
// Constructors can reference the header var, but no other GlobalTypeVars.
// In theory, a TypeData could be nested, so the header scope
// should be tracked recursively, but it is unclear that we need
// to support it.
TypeData td = GetRef<TypeData>(op);
CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header");
for (const auto& var : op->type_vars) {
CheckKindMatches(var, td, Kind::kType, "ADT type var");
}
for (const auto& con : op->constructors) {
if (!con->belong_to.same_as(op->header)) {
ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to
<< " but " << op << "has header " << op->header));
}
for (const Type& t : con->inputs) {
CheckKindMatches(t, td, Kind::kType, "ADT constructor input");
} }
} }
return Kind::kTypeData;
} }
bool Check(const Type& t) { Kind Check(const Type& t) {
this->VisitType(t); return this->VisitType(t);
return valid;
} }
}; };
bool KindCheck(const Type& t, const Module& mod) { Kind KindCheck(const Type& t, const Module& mod) {
KindChecker kc; KindChecker kc(mod);
return kc.Check(t); return kc.Check(t);
} }
TVM_REGISTER_API("relay._ir_pass.check_kind") TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) { if (args.size() == 1) {
*ret = KindCheck(args[0], ModuleNode::make({})); *ret = KindCheck(args[0], ModuleNode::make({}, {}));
} else { } else {
*ret = KindCheck(args[0], args[1]); *ret = KindCheck(args[0], args[1]);
} }
......
...@@ -62,7 +62,7 @@ class LetList { ...@@ -62,7 +62,7 @@ class LetList {
* \return a Var that hold the inserted expr. * \return a Var that hold the inserted expr.
*/ */
Var Push(Expr expr) { Var Push(Expr expr) {
return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr); return Push(IncompleteTypeNode::make(Kind::kType), expr);
} }
/*! /*!
......
...@@ -274,7 +274,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> { ...@@ -274,7 +274,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
} }
Expr VisitExpr(const Expr& e) { Expr VisitExpr(const Expr& e) {
Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType)); Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
return this->VisitExpr(e, v); return this->VisitExpr(e, v);
} }
......
...@@ -189,6 +189,20 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -189,6 +189,20 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return RefTypeNode::make(Unify(op->value, rtn->value)); return RefTypeNode::make(Unify(op->value, rtn->value));
} }
Type VisitType_(const TypeCallNode* op, const Type& tn) override {
const auto* tcn = tn.as<TypeCallNode>();
if (!tcn || tcn->args.size() != op->args.size()) {
return Type();
}
Type func = Unify(op->func, tcn->func);
tvm::Array<Type> args;
for (size_t i = 0; i < op->args.size(); i++) {
args.push_back(Unify(op->args[i], tcn->args[i]));
}
return TypeCallNode::make(func, args);
}
private: private:
TypeSolver* solver_; TypeSolver* solver_;
}; };
...@@ -266,6 +280,16 @@ class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> { ...@@ -266,6 +280,16 @@ class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> {
} }
} }
void VisitType_(const TypeCallNode* op) override {
TypeCall tc = GetRef<TypeCall>(op);
UpdateRelSet(tc);
Propagate(tc->func);
for (auto arg : tc->args) {
Propagate(arg);
}
}
private: private:
TypeSolver* solver_; TypeSolver* solver_;
const std::unordered_set<RelationNode*>* rels_; const std::unordered_set<RelationNode*>* rels_;
...@@ -494,7 +518,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") ...@@ -494,7 +518,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
} else if (name == "AddConstraint") { } else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) { return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
Expr e = VarNode::make("dummy_var", Expr e = VarNode::make("dummy_var",
IncompleteTypeNode::make(TypeVarNode::Kind::kType)); IncompleteTypeNode::make(Kind::kType));
return solver->AddConstraint(c, e); return solver->AddConstraint(c, e);
}); });
} else { } else {
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
*/ */
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "../ir/type_functor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
...@@ -51,6 +52,8 @@ class TypeVarTVisitor : public TypeVisitor { ...@@ -51,6 +52,8 @@ class TypeVarTVisitor : public TypeVisitor {
class TypeVarEVisitor : private ExprVisitor { class TypeVarEVisitor : private ExprVisitor {
public: public:
explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {}
Array<TypeVar> CollectFree() { Array<TypeVar> CollectFree() {
Array<TypeVar> ret; Array<TypeVar> ret;
for (const auto& v : type_vars_.data) { for (const auto& v : type_vars_.data) {
...@@ -115,6 +118,16 @@ class TypeVarEVisitor : private ExprVisitor { ...@@ -115,6 +118,16 @@ class TypeVarEVisitor : private ExprVisitor {
ExprVisitor::VisitExpr_(f); ExprVisitor::VisitExpr_(f);
} }
void VisitExpr_(const ConstructorNode* cn) final {
// for constructors, type vars will be bound in the module
auto data = mod_->LookupDef(cn->belong_to);
for (const auto& tv : data->type_vars) {
type_vars_.Insert(tv);
bound_type_vars_.Insert(tv);
}
ExprVisitor::VisitExpr_(cn);
}
void VisitType(const Type& t) final { void VisitType(const Type& t) final {
TypeVarTVisitor(&type_vars_, &bound_type_vars_) TypeVarTVisitor(&type_vars_, &bound_type_vars_)
.VisitType(t); .VisitType(t);
...@@ -123,9 +136,10 @@ class TypeVarEVisitor : private ExprVisitor { ...@@ -123,9 +136,10 @@ class TypeVarEVisitor : private ExprVisitor {
private: private:
InsertionSet<TypeVar> type_vars_; InsertionSet<TypeVar> type_vars_;
InsertionSet<TypeVar> bound_type_vars_; InsertionSet<TypeVar> bound_type_vars_;
const Module& mod_;
}; };
class VarVisitor : protected ExprVisitor { class VarVisitor : protected ExprVisitor, protected PatternVisitor {
public: public:
Array<Var> Free(const Expr& expr) { Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr); this->VisitExpr(expr);
...@@ -178,33 +192,41 @@ class VarVisitor : protected ExprVisitor { ...@@ -178,33 +192,41 @@ class VarVisitor : protected ExprVisitor {
VisitExpr(op->body); VisitExpr(op->body);
} }
void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}
void VisitPattern_(const PatternVarNode* op) final {
MarkBounded(op->var);
}
private: private:
InsertionSet<Var> vars_; InsertionSet<Var> vars_;
InsertionSet<Var> bound_vars_; InsertionSet<Var> bound_vars_;
}; };
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) { tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod) {
return TypeVarEVisitor().Free(expr); return TypeVarEVisitor(mod).Free(expr);
} }
tvm::Array<TypeVar> FreeTypeVars(const Type& type) { tvm::Array<TypeVar> FreeTypeVars(const Type& type, const Module& mod) {
return TypeVarEVisitor().Free(type); return TypeVarEVisitor(mod).Free(type);
} }
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr) { tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod) {
return TypeVarEVisitor().Bound(expr); return TypeVarEVisitor(mod).Bound(expr);
} }
tvm::Array<TypeVar> BoundTypeVars(const Type& type) { tvm::Array<TypeVar> BoundTypeVars(const Type& type, const Module& mod) {
return TypeVarEVisitor().Bound(type); return TypeVarEVisitor(mod).Bound(type);
} }
tvm::Array<TypeVar> AllTypeVars(const Expr& expr) { tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod) {
return TypeVarEVisitor().All(expr); return TypeVarEVisitor(mod).All(expr);
} }
tvm::Array<TypeVar> AllTypeVars(const Type& type) { tvm::Array<TypeVar> AllTypeVars(const Type& type, const Module& mod) {
return TypeVarEVisitor().All(type); return TypeVarEVisitor(mod).All(type);
} }
tvm::Array<Var> FreeVars(const Expr& expr) { tvm::Array<Var> FreeVars(const Expr& expr) {
...@@ -237,30 +259,33 @@ TVM_REGISTER_API("relay._ir_pass.all_vars") ...@@ -237,30 +259,33 @@ TVM_REGISTER_API("relay._ir_pass.all_vars")
TVM_REGISTER_API("relay._ir_pass.free_type_vars") TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) { if (x.as_derived<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x)); *ret = FreeTypeVars(Downcast<Type>(x), mod);
} else { } else {
*ret = FreeTypeVars(Downcast<Expr>(x)); *ret = FreeTypeVars(Downcast<Expr>(x), mod);
} }
}); });
TVM_REGISTER_API("relay._ir_pass.bound_type_vars") TVM_REGISTER_API("relay._ir_pass.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) { if (x.as_derived<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x)); *ret = BoundTypeVars(Downcast<Type>(x), mod);
} else { } else {
*ret = BoundTypeVars(Downcast<Expr>(x)); *ret = BoundTypeVars(Downcast<Expr>(x), mod);
} }
}); });
TVM_REGISTER_API("relay._ir_pass.all_type_vars") TVM_REGISTER_API("relay._ir_pass.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0]; NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) { if (x.as_derived<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x)); *ret = AllTypeVars(Downcast<Type>(x), mod);
} else { } else {
*ret = AllTypeVars(Downcast<Expr>(x)); *ret = AllTypeVars(Downcast<Expr>(x), mod);
} }
}); });
......
...@@ -13,7 +13,10 @@ TEST(Relay, SelfReference) { ...@@ -13,7 +13,10 @@ TEST(Relay, SelfReference) {
auto y = relay::VarNode::make("y", tensor_type); auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y }); auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {}); auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{})); auto empty_module =
relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{},
Map<relay::GlobalTypeVar, relay::TypeData>{});
auto type_fx = relay::InferType(fx, empty_module);
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {}); auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(AlphaEqual(type_fx->checked_type(), expected)); CHECK(AlphaEqual(type_fx->checked_type(), expected));
......
...@@ -171,6 +171,29 @@ def test_type_relation_alpha_equal(): ...@@ -171,6 +171,29 @@ def test_type_relation_alpha_equal():
assert bigger != diff_num_inputs assert bigger != diff_num_inputs
def test_type_call_alpha_equal():
h1 = relay.GlobalTypeVar("h1")
h2 = relay.GlobalTypeVar("h2")
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
t3 = relay.TensorType((1, 2, 3, 4), "float32")
t4 = relay.TensorType((), "float32")
tc = relay.TypeCall(h1, [t1, t2, t3])
same = relay.TypeCall(h1, [t1, t2, t3])
different_func = relay.TypeCall(h2, [t1, t2, t3])
different_arg = relay.TypeCall(h1, [t1, t2, t4])
fewer_args = relay.TypeCall(h1, [t1, t2])
more_args = relay.TypeCall(h1, [t1, t2, t3, t4])
different_order_args = relay.TypeCall(h1, [t3, t2, t1])
assert tc == same
assert tc != different_func
assert tc != fewer_args
assert tc != more_args
assert tc != different_order_args
def test_constant_alpha_equal(): def test_constant_alpha_equal():
x = relay.const(1) x = relay.const(1)
...@@ -453,6 +476,79 @@ def test_if_alpha_equal(): ...@@ -453,6 +476,79 @@ def test_if_alpha_equal():
assert not alpha_equal(if_sample, different_false) assert not alpha_equal(if_sample, different_false)
def test_constructor_alpha_equal():
# smoke test: it should be pointer equality
mod = relay.Module()
p = relay.prelude.Prelude(mod)
assert alpha_equal(p.nil, p.nil)
assert alpha_equal(p.cons, p.cons)
assert not alpha_equal(p.nil, p.cons)
def test_match_alpha_equal():
mod = relay.Module()
p = relay.prelude.Prelude(mod)
x = relay.Var('x')
y = relay.Var('y')
nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil())
cons_case = relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(x),
relay.PatternVar(y)]),
p.cons(x, y))
z = relay.Var('z')
a = relay.Var('a')
equivalent_cons = relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(z),
relay.PatternVar(a)]),
p.cons(z, a))
data = p.cons(p.z(), p.cons(p.z(), p.nil()))
match = relay.Match(data, [nil_case, cons_case])
equivalent = relay.Match(data, [nil_case, equivalent_cons])
empty = relay.Match(data, [])
no_cons = relay.Match(data, [nil_case])
no_nil = relay.Match(data, [cons_case])
different_data = relay.Match(p.nil(), [nil_case, cons_case])
different_order = relay.Match(data, [cons_case, nil_case])
different_nil = relay.Match(data, [
relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())),
cons_case
])
different_cons = relay.Match(data, [
nil_case,
relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternWildcard(),
relay.PatternWildcard()]),
p.nil())
])
another_case = relay.Match(data, [
nil_case,
cons_case,
relay.Clause(relay.PatternWildcard(), p.nil())
])
wrong_constructors = relay.Match(data, [
relay.Clause(relay.PatternConstructor(p.z), p.nil()),
relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]),
p.cons(x, p.nil()))
])
assert alpha_equal(match, match)
assert alpha_equal(match, equivalent)
assert not alpha_equal(match, no_cons)
assert not alpha_equal(match, no_nil)
assert not alpha_equal(match, empty)
assert not alpha_equal(match, different_data)
assert not alpha_equal(match, different_order)
assert not alpha_equal(match, different_nil)
assert not alpha_equal(match, different_cons)
assert not alpha_equal(match, another_case)
assert not alpha_equal(match, wrong_constructors)
def test_op_alpha_equal(): def test_op_alpha_equal():
# only checks names # only checks names
op1 = relay.op.get("add") op1 = relay.op.get("add")
...@@ -491,6 +587,7 @@ if __name__ == "__main__": ...@@ -491,6 +587,7 @@ if __name__ == "__main__":
test_func_type_alpha_equal() test_func_type_alpha_equal()
test_tuple_type_alpha_equal() test_tuple_type_alpha_equal()
test_type_relation_alpha_equal() test_type_relation_alpha_equal()
test_type_call_alpha_equal()
test_constant_alpha_equal() test_constant_alpha_equal()
test_global_var_alpha_equal() test_global_var_alpha_equal()
test_tuple_alpha_equal() test_tuple_alpha_equal()
...@@ -499,6 +596,8 @@ if __name__ == "__main__": ...@@ -499,6 +596,8 @@ if __name__ == "__main__":
test_call_alpha_equal() test_call_alpha_equal()
test_let_alpha_equal() test_let_alpha_equal()
test_if_alpha_equal() test_if_alpha_equal()
test_constructor_alpha_equal()
test_match_alpha_equal()
test_op_alpha_equal() test_op_alpha_equal()
test_var_alpha_equal() test_var_alpha_equal()
test_graph_equal() test_graph_equal()
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.ir_pass import check_kind from tvm.relay.ir_pass import check_kind
from nose.tools import raises
def test_typevar_kind():
# returns the same kind
tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tp3 = relay.TypeVar('tp3', relay.Kind.Constraint)
assert check_kind(tp1) == relay.Kind.Type
assert check_kind(tp2) == relay.Kind.Shape
assert check_kind(tp3) == relay.Kind.Constraint
def test_tuple_kind(): def test_tuple_kind():
# only contain type kinds # only contain type kinds
...@@ -10,7 +23,7 @@ def test_tuple_kind(): ...@@ -10,7 +23,7 @@ def test_tuple_kind():
fields = tvm.convert([tp, tf, tt]) fields = tvm.convert([tp, tf, tt])
tup_ty = relay.TupleType(fields) tup_ty = relay.TupleType(fields)
assert check_kind(tup_ty) assert check_kind(tup_ty) == relay.Kind.Type
def test_func_kind(): def test_func_kind():
...@@ -30,7 +43,20 @@ def test_func_kind(): ...@@ -30,7 +43,20 @@ def test_func_kind():
ret_type = relay.TupleType(tvm.convert([tp2, tensor_type])) ret_type = relay.TupleType(tvm.convert([tp2, tensor_type]))
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert check_kind(tf) assert check_kind(tf) == relay.Kind.Type
def test_ref_kind():
# only contain type kinds
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
ft = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
rt1 = relay.RefType(tt)
assert check_kind(rt1) == relay.Kind.Type
rt2 = relay.RefType(ft)
assert check_kind(rt2) == relay.Kind.Type
rt3 = relay.RefType(relay.TupleType([rt1, rt2]))
assert check_kind(rt3) == relay.Kind.Type
def test_relation_kind(): def test_relation_kind():
...@@ -41,9 +67,35 @@ def test_relation_kind(): ...@@ -41,9 +67,35 @@ def test_relation_kind():
args = tvm.convert([tf, tt, tp]) args = tvm.convert([tf, tt, tp])
tr = relay.TypeRelation(None, args, 2, None) tr = relay.TypeRelation(None, args, 2, None)
assert check_kind(tr) assert check_kind(tr) == relay.Kind.Constraint
def test_global_typevar_kind():
v1 = relay.GlobalTypeVar('gtv1', relay.Kind.AdtHandle)
v2 = relay.GlobalTypeVar('gtv2', relay.Kind.Type)
assert check_kind(v1) == relay.Kind.AdtHandle
assert check_kind(v2) == relay.Kind.Type
def test_typecall_kind():
gtv = relay.GlobalTypeVar('gtv')
mod = relay.Module()
data = relay.TypeData(gtv, [], [])
mod[gtv] = data
empty_call = relay.TypeCall(gtv, [])
assert check_kind(empty_call, mod) == relay.Kind.Type
new_mod = relay.Module()
tv = relay.TypeVar('tv')
new_data = relay.TypeData(gtv, [tv], [])
new_mod[gtv] = new_data
call = relay.TypeCall(gtv, [relay.TupleType([])])
assert check_kind(call, new_mod) == relay.Kind.Type
@raises(tvm._ffi.base.TVMError)
def test_invalid_tuple_kind(): def test_invalid_tuple_kind():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
...@@ -51,9 +103,10 @@ def test_invalid_tuple_kind(): ...@@ -51,9 +103,10 @@ def test_invalid_tuple_kind():
fields = tvm.convert([tp1, tp2, tp3]) fields = tvm.convert([tp1, tp2, tp3])
tup_ty = relay.TupleType(fields) tup_ty = relay.TupleType(fields)
assert not check_kind(tup_ty) check_kind(tup_ty)
@raises(tvm._ffi.base.TVMError)
def test_invalid_func_kind(): def test_invalid_func_kind():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
...@@ -65,51 +118,98 @@ def test_invalid_func_kind(): ...@@ -65,51 +118,98 @@ def test_invalid_func_kind():
ret_type = tp3 ret_type = tp3
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints) tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert not check_kind(tf) check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_invalid_ref_kind():
tp = relay.TypeVar('tp', relay.Kind.Shape)
rt = relay.RefType(tp)
check_kind(rt)
@raises(tvm._ffi.base.TVMError)
def test_invalid_relation_kind(): def test_invalid_relation_kind():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType) tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
args = tvm.convert([tp1, tp2, tp3]) args = tvm.convert([tp1, tp2, tp3])
tr = relay.TypeRelation(None, args, 2, None) func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
assert not check_kind(tr) tr = relay.TypeRelation(func, args, 2, None)
check_kind(tr)
@raises(tvm._ffi.base.TVMError)
def test_typecall_invalid_callee():
# global type var must be an ADT handle
gtv = relay.GlobalTypeVar('v1', relay.Kind.Type)
check_kind(relay.TypeCall(gtv, []))
@raises(tvm._ffi.base.TVMError)
def test_typecall_invalid_args():
# args must all be type kind
mod = relay.Module()
gtv = relay.GlobalTypeVar('v1')
data = relay.TypeData(gtv, [], [])
mod[gtv] = data
check_kind(relay.TypeCall(gtv, [data]))
@raises(tvm._ffi.base.TVMError)
def test_typecall_invalid_num_args():
mod = relay.Module()
gtv = relay.GlobalTypeVar('v1')
tv = relay.TypeVar('tv')
data = relay.TypeData(gtv, [tv], [])
mod[gtv] = data
check_kind(relay.TypeCall(gtv, []))
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_ret_type(): def test_func_with_invalid_ret_type():
tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_arg_types(): def test_func_with_invalid_arg_types():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.Type) tp2 = relay.TypeVar('tp2', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([])) tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_tuple(): def test_func_with_invalid_tuple():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape) tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1])) ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1]))
tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([]))
assert not check_kind(tf) check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_relation(): def test_func_with_invalid_relation():
tp1 = relay.TypeVar('tp1', relay.Kind.Type) tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape) tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar) tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None) func = tvm.get_env_func("tvm.relay.type_relation.Identity")
tr = relay.TypeRelation(func, tvm.convert([tp2, tp3]), 1, None)
tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr])) tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr]))
assert not check_kind(tf) check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_tuple_with_invalid_func(): def test_tuple_with_invalid_func():
tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32') tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
...@@ -117,16 +217,23 @@ def test_tuple_with_invalid_func(): ...@@ -117,16 +217,23 @@ def test_tuple_with_invalid_func():
tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([])) tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([]))
tup_ty = relay.TupleType(tvm.convert([tensor_type, tf])) tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
assert not check_kind(tup_ty) check_kind(tup_ty)
if __name__ == "__main__": if __name__ == "__main__":
test_tuple_kind() test_tuple_kind()
test_func_kind() test_func_kind()
test_ref_kind()
test_relation_kind() test_relation_kind()
test_global_typevar_kind()
test_typecall_kind()
test_invalid_tuple_kind() test_invalid_tuple_kind()
test_invalid_func_kind() test_invalid_func_kind()
test_invalid_ref_kind()
test_invalid_relation_kind() test_invalid_relation_kind()
test_typecall_invalid_callee()
test_typecall_invalid_args()
test_typecall_invalid_num_args()
test_func_with_invalid_ret_type() test_func_with_invalid_ret_type()
test_func_with_invalid_arg_types() test_func_with_invalid_arg_types()
test_func_with_invalid_tuple() test_func_with_invalid_tuple()
......
...@@ -65,6 +65,40 @@ def test_bound_vars(): ...@@ -65,6 +65,40 @@ def test_bound_vars():
assert_vars_match(bound_vars(f2), [x, y]) assert_vars_match(bound_vars(f2), [x, y])
def test_match_vars():
mod = relay.Module()
p = relay.prelude.Prelude(mod)
x = relay.Var('x')
y = relay.Var('y')
z = relay.Var('z')
match1 = relay.Match(p.nil(), [
relay.Clause(relay.PatternConstructor(p.nil), z),
relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(x),
relay.PatternVar(y)]),
p.cons(x, y))
])
match2 = relay.Match(p.nil(), [
relay.Clause(relay.PatternConstructor(p.cons, [
relay.PatternWildcard(),
relay.PatternVar(x)
]),
y),
relay.Clause(relay.PatternWildcard(), z)
])
assert_vars_match(bound_vars(match1), [x, y])
assert_vars_match(free_vars(match1), [z])
assert_vars_match(all_vars(match1), [z, x, y])
assert_vars_match(bound_vars(match2), [x])
assert_vars_match(free_vars(match2), [y, z])
assert_vars_match(all_vars(match2), [x, y, z])
def test_bound_type_vars(): def test_bound_type_vars():
a = relay.TypeVar("a") a = relay.TypeVar("a")
b = relay.TypeVar("b") b = relay.TypeVar("b")
......
...@@ -17,6 +17,16 @@ def assert_has_type(expr, typ, mod=relay.module.Module({})): ...@@ -17,6 +17,16 @@ def assert_has_type(expr, typ, mod=relay.module.Module({})):
checked_type, typ)) checked_type, typ))
# initializes simple ADT for tests
def initialize_box_adt(mod):
box = relay.GlobalTypeVar('box')
tv = relay.TypeVar('tv')
constructor = relay.Constructor('constructor', [tv], box)
data = relay.TypeData(box, [tv], [constructor])
mod[box] = data
return (box, constructor)
def test_monomorphic_let(): def test_monomorphic_let():
"Program: let x = 1; return x" "Program: let x = 1; return x"
sb = relay.ScopeBuilder() sb = relay.ScopeBuilder()
...@@ -190,6 +200,69 @@ def test_equal(): ...@@ -190,6 +200,69 @@ def test_equal():
assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool')) assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool'))
def test_constructor_type():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
a = relay.TypeVar('a')
x = relay.Var('x', a)
ct = relay.ir_pass.infer_type(
relay.Function([x], constructor(x), box(a), [a]), mod)
expected = relay.FuncType([a], box(a), [a])
assert ct.checked_type == expected
def test_constructor_call():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
box_unit = constructor(relay.Tuple([]))
box_constant = constructor(relay.const(0, 'float32'))
ut = relay.ir_pass.infer_type(box_unit, mod)
ct = relay.ir_pass.infer_type(box_constant, mod)
assert ut.checked_type == box(relay.TupleType([]))
assert ct.checked_type == box(relay.TensorType((), 'float32'))
def test_adt_match():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
v = relay.Var('v', relay.TensorType((), 'float32'))
match = relay.Match(constructor(relay.const(0, 'float32')),
[relay.Clause(
relay.PatternConstructor(constructor,
[relay.PatternVar(v)]),
relay.Tuple([])),
# redundant but shouldn't matter to typechecking
relay.Clause(relay.PatternWildcard(),
relay.Tuple([]))])
mt = relay.ir_pass.infer_type(match, mod)
assert mt.checked_type == relay.TupleType([])
def test_adt_match_type_annotations():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
# the only type annotation is inside the match pattern var
# but that should be enough info
tt = relay.TensorType((2, 2), 'float32')
x = relay.Var('x')
mv = relay.Var('mv', tt)
match = relay.Match(constructor(x),
[relay.Clause(
relay.PatternConstructor(constructor,
[relay.PatternVar(mv)]),
relay.Tuple([]))])
func = relay.Function([x], match)
ft = relay.ir_pass.infer_type(func, mod)
assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))
if __name__ == "__main__": if __name__ == "__main__":
test_free_expr() test_free_expr()
test_dual_op() test_dual_op()
...@@ -205,3 +278,6 @@ if __name__ == "__main__": ...@@ -205,3 +278,6 @@ if __name__ == "__main__":
test_global_var_recursion() test_global_var_recursion()
test_equal() test_equal()
test_ref() test_ref()
test_constructor_type()
test_constructor_call()
test_adt_match()
...@@ -62,6 +62,30 @@ def test_unify_tuple(): ...@@ -62,6 +62,30 @@ def test_unify_tuple():
assert unified == tup2 assert unified == tup2
def test_unify_global_type_var():
# should only be able to unify if they're the same
solver = make_solver()
gtv = relay.GlobalTypeVar('gtv')
unified = solver.Unify(gtv, gtv)
assert unified == gtv
def test_unify_typecall():
solver = make_solver()
gtv = relay.GlobalTypeVar('gtv')
# yeah, typecalls are shaped like tuples so the same
# tests work out
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.TensorType((10, 20), "float32")
tc1 = relay.ty.TypeCall(gtv, [t1, t2])
tc2 = relay.ty.TypeCall(gtv, [t3, t3])
unified = solver.Unify(tc1, tc2)
assert unified == tc2
def test_unify_functype(): def test_unify_functype():
solver = make_solver() solver = make_solver()
t1 = relay.ty.IncompleteType() t1 = relay.ty.IncompleteType()
...@@ -205,10 +229,49 @@ def test_bad_recursive_unification(): ...@@ -205,10 +229,49 @@ def test_bad_recursive_unification():
solver.Unify(t1, relay.ty.TupleType([t1, t1])) solver.Unify(t1, relay.ty.TupleType([t1, t1]))
@raises(tvm._ffi.base.TVMError)
def test_unify_invalid_global_typevars():
solver = make_solver()
gtv1 = relay.GlobalTypeVar('gtv1')
gtv2 = relay.GlobalTypeVar('gtv2')
solver.Unify(gtv1, gtv2)
@raises(tvm._ffi.base.TVMError)
def test_incompatible_typecall_var_unification():
solver = make_solver()
gtv1 = relay.GlobalTypeVar('gtv1')
gtv2 = relay.GlobalTypeVar('gtv2')
t1 = relay.IncompleteType()
t2 = relay.IncompleteType()
tc1 = relay.TypeCall(gtv1, [t1])
tc2 = relay.TypeCall(gtv2, [t2])
solver.Unify(tc1, tc2)
@raises(tvm._ffi.base.TVMError)
def test_incompatible_typecall_args_unification():
solver = make_solver()
gtv = relay.GlobalTypeVar('gtv1')
t1 = relay.IncompleteType()
t2 = relay.IncompleteType()
tensor1 = relay.TensorType((1, 2, 3), "float32")
tensor2 = relay.TensorType((2, 3), "float32")
tensor3 = relay.TensorType((3,), "float32")
tc1 = relay.TypeCall(gtv, [relay.TupleType([t1, t1]), t2])
tc2 = relay.TypeCall(gtv, [relay.TupleType([tensor1, tensor2]), tensor3])
solver.Unify(tc1, tc2)
if __name__ == "__main__": if __name__ == "__main__":
test_bcast() test_bcast()
test_backward_solving() test_backward_solving()
test_unify_tuple() test_unify_tuple()
test_unify_typecall()
test_unify_functype() test_unify_functype()
test_recursive_unify() test_recursive_unify()
test_unify_vars_under_tuples() test_unify_vars_under_tuples()
...@@ -216,3 +279,5 @@ if __name__ == "__main__": ...@@ -216,3 +279,5 @@ if __name__ == "__main__":
test_backward_solving_after_child_update() test_backward_solving_after_child_update()
test_incompatible_tuple_unification() test_incompatible_tuple_unification()
test_bad_recursive_unification() test_bad_recursive_unification()
test_incompatible_typecall_var_unification()
test_incompatible_typecall_args_unification()
from tvm import relay
from tvm.relay.ir_pass import infer_type
def test_dup_type():
a = relay.TypeVar("a")
av = relay.Var("av", a)
make_id = relay.Function([av], relay.Tuple([av, av]), None, [a])
t = relay.scalar_type("float32")
b = relay.Var("b", t)
assert relay.ir_pass.infer_type(make_id(b)).checked_type == relay.TupleType([t, t])
def test_id_type():
mod = relay.Module()
id_type = relay.GlobalTypeVar("id")
a = relay.TypeVar("a")
mod[id_type] = relay.TypeData(id_type, [a], [])
b = relay.TypeVar("b")
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
t = relay.scalar_type("float32")
b = relay.Var("b", t)
assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t)
if __name__ == "__main__":
test_dup_type()
test_id_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment