Commit 2ae3124f by Steven S. Lyubomirsky Committed by ziheng

[Relay] Algebraic data types (#2442)

* First pass on ADTs

* Add doc string for tag field

* Visit constructors in TypeVisitor for TypeData

* Add to description of type call

* Add type call to type solving and unification

* Make type mutator for typecall consistent with others (only create new node if there's a change)

* Ensure kindchecking can handle type calls and typedata

* Fix bad nesting in module constructor

* Correctly construct call in typecall test

* Add call override for ordinary vars (do we want this?)

* Remove generalization hack from type inference because it was breaking ADT constructors

* Check that there are no free type vars in exprs after inferring type

* Free var checks need module because of ADT constructors

* Typecall test can't have unbound type var, make it global

* Uncomment tmap test and remove comments about failing to infer ret type; those work now

* Put in dummy visits for ADTs in graph runtime codegen to placate pylint

* Fix Relay type infer test module constructor

* Mark override for TypeCallNode in type solver

* Ensure free vars check treats patern vars as bound

* Run interpreter in more ADT test cases

* Refactor kind check to return the kind, like typechecking

* Fix invalid typecall in test

* Add kind check to type inference, do not use nulls in func_type_annotation()!

* Redundant whitespace

* Make TypeData a separate kind

* Make ADT handles a separate kind too, document calling convention better

* Remove nats and tree from prelude, move to test, document prelude

* Restore and document nat and tree to prelude, add more tree tests

* Add alpha equality tests for match cases, fix variable binding bug

* Add more kind check tests for ADTs

* Add more tests for finding free or bound vars in match exprs

* Add unification tests for type call

* Update main() for alpha equality tests

* Add simple type inference test cases for match exprs and ADT constructors

* Add more ADT interpreter tests

* Allow incomplete types when typechecking match cases

* Type inference for pattern vars should use the type annotation if it's there

* Two more specific test cases for ADT matching

* Add option ADT to prelude

* Fix broken reference to kind enum

* Fix rebase snags

* Do not attach checked types to constructors

* More docstrings for module fields

* Use proper wrapper for indexing into module type data

* checked_type for constructors is not populated

* Expand type call docstring

* Rename PatternConstructor con field

* Use error reporter for pattern constructor case

* Condense error reporting in kind check, use error reporter

* Expand docstrings and rename ADT fields

* Rename 'option' ADT to 'optional' for consistency with Python

* Add various list iterators and utility functions to prelude

* Add smoke tests for new iterators in prelude

* Add concat to prelude

* Add smoke test for concat

* Correct docstrings in prelude

* Ensure that type defs are written in module initialization

* Various requested renamings

* Correct rebase snags

* Add kind check tests for ref types

* Update the main() for kind checking tests
parent c7165427
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/adt.h
* \brief Algebraic data types for Relay
*/
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_
#include <tvm/attrs.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"
namespace tvm {
namespace relay {
/*! \brief Base type for declaring relay pattern. */
class PatternNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Pattern";
TVM_DECLARE_BASE_NODE_INFO(PatternNode, Node);
};
/*!
* \brief Pattern is the base type for an ADT match pattern in Relay.
*
* Given an ADT value, a pattern might accept it and bind the pattern variable to some value
* (typically a subnode of the input or the input). Otherwise, the pattern rejects the value.
*
* ADT pattern matching thus takes a list of values and binds to the first that accepts the value.
*/
class Pattern : public NodeRef {
public:
Pattern() {}
explicit Pattern(NodePtr<tvm::Node> p) : NodeRef(p) {}
using ContainerType = PatternNode;
};
/*! \brief A wildcard pattern: Accepts all input and binds nothing. */
class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
PatternWildcardNode() {}
TVM_DLL static PatternWildcard make();
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternWildcard";
TVM_DECLARE_NODE_TYPE_INFO(PatternWildcardNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternWildcard, PatternWildcardNode, Pattern);
/*! \brief A var pattern. Accept all input and bind to a var. */
class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
public:
PatternVarNode() {}
/*! \brief Variable that stores the matched value. */
tvm::relay::Var var;
TVM_DLL static PatternVar make(tvm::relay::Var var);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternVar";
TVM_DECLARE_NODE_TYPE_INFO(PatternVarNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternVar, PatternVarNode, Pattern);
/*!
* \brief ADT constructor.
* Constructors compare by pointer equality.
*/
class Constructor;
/*! \brief Constructor container node. */
class ConstructorNode : public ExprNode {
public:
/*! \brief The name (only a hint) */
std::string name_hint;
/*! \brief Input to the constructor. */
tvm::Array<Type> inputs;
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int tag = -1;
ConstructorNode() {}
TVM_DLL static Constructor make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("inputs", &inputs);
v->Visit("belong_to", &belong_to);
v->Visit("tag", &tag);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.Constructor";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Constructor, ConstructorNode, Expr);
/*! \brief A constructor pattern. Matches a value with the given constructor, binds recursively. */
class PatternConstructor;
/*! \brief PatternVar container node */
class PatternConstructorNode : public PatternNode {
public:
/*! Constructor matched by the pattern. */
Constructor constructor;
/*! Sub-patterns to match against each input to the constructor. */
tvm::Array<Pattern> patterns;
PatternConstructorNode() {}
TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.PatternConstructor";
TVM_DECLARE_NODE_TYPE_INFO(PatternConstructorNode, PatternNode);
};
RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);
/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
* In particular, it stores the handle (global type var) for an ADT
* and the constructors used to build it and is kept in the module. Note
* that type parameters are also indicated in the type data: this means that
* for any instance of an ADT, the type parameters must be indicated. That is,
* an ADT definition is treated as a type-level function, so an ADT handle
* must be wrapped in a TypeCall node that instantiates the type-level arguments.
* The kind checker enforces this.
*/
class TypeData;
/*! \brief TypeData container node */
class TypeDataNode : public TypeNode {
public:
/*!
* \brief The header is simply the name of the ADT.
* We adopt nominal typing for ADT definitions;
* that is, differently-named ADT definitions with same constructors
* have different types.
*/
GlobalTypeVar header;
/*! \brief The type variables (to allow for polymorphism). */
tvm::Array<TypeVar> type_vars;
/*! \brief The constructors. */
tvm::Array<Constructor> constructors;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("header", &header);
v->Visit("type_vars", &type_vars);
v->Visit("constructors", &constructors);
v->Visit("span", &span);
}
TVM_DLL static TypeData make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors);
static constexpr const char* _type_key = "relay.TypeData";
TVM_DECLARE_NODE_TYPE_INFO(TypeDataNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeData, TypeDataNode, Type);
/*! \brief A clause in a match expression. */
class Clause;
/*! \brief Clause container node. */
class ClauseNode : public Node {
public:
/*! \brief The pattern the clause matches. */
Pattern lhs;
/*! \brief The resulting value. */
Expr rhs;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("lhs", &lhs);
v->Visit("rhs", &rhs);
}
TVM_DLL static Clause make(Pattern lhs, Expr rhs);
static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_NODE_TYPE_INFO(ClauseNode, Node);
};
RELAY_DEFINE_NODE_REF(Clause, ClauseNode, NodeRef);
/*! \brief ADT pattern matching exression. */
class Match;
/*! \brief Match container node. */
class MatchNode : public ExprNode {
public:
/*! \brief The input being deconstructed. */
Expr data;
/*! \brief The match node clauses. */
tvm::Array<Clause> clauses;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("clause", &clauses);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern);
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_NODE_TYPE_INFO(MatchNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(Match, MatchNode, Expr);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ADT_H_
......@@ -10,6 +10,7 @@
#include <tvm/node/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./adt.h"
#include "./op.h"
#include "./error.h"
......@@ -92,6 +93,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
......@@ -114,6 +117,8 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode);
RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode);
return vtable;
}
};
......@@ -142,7 +147,11 @@ class ExprVisitor
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);
protected:
// Internal visiting counter
......@@ -180,6 +189,9 @@ class ExprMutator
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
......@@ -188,6 +200,8 @@ class ExprMutator
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual Clause VisitClause(const Clause& c);
virtual Pattern VisitPattern(const Pattern& c);
protected:
/*! \brief Internal map used for memoization. */
......
......@@ -160,6 +160,28 @@ struct RefValueNode : ValueNode {
RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
/*! \brief An ADT constructor value. */
class ConstructorValue;
struct ConstructorValueNode : ValueNode {
Constructor constructor;
tvm::Array<Value> fields;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("constructor", &constructor);
v->Visit("fields", &fields);
}
TVM_DLL static ConstructorValue make(Constructor constructor,
tvm::Array<Value> fields);
static constexpr const char* _type_key = "relay.ConstructorValue";
TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(ConstructorValue, ConstructorValueNode, Value);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_INTERPRETER_H_
......@@ -9,6 +9,7 @@
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <string>
......@@ -35,13 +36,15 @@ struct Module;
*
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* an Module while auto-tuning.
* a Module while auto-tuning.
* */
class ModuleNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;
/*! \brief The entry function (i.e. "main"). */
GlobalVar entry_func;
......@@ -50,15 +53,18 @@ class ModuleNode : public RelayNode {
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("type_definitions", &type_definitions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func);
v->Visit("global_type_var_map_", &global_type_var_map_);
}
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs);
/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param var The var of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
......@@ -66,6 +72,13 @@ class ModuleNode : public RelayNode {
void Add(const GlobalVar& var, const Function& func, bool update = false);
/*!
* \brief Add a type-level definition to the global environment.
* \param var The var of the global type definition.
* \param type The type definition.
*/
void AddDef(const GlobalTypeVar& var, const TypeData& type);
/*!
* \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
......@@ -95,6 +108,13 @@ class ModuleNode : public RelayNode {
GlobalVar GetGlobalVar(const std::string& str);
/*!
* \brief Look up a global function by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalTypeVar GetGlobalTypeVar(const std::string& str);
/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
......@@ -109,6 +129,20 @@ class ModuleNode : public RelayNode {
Function Lookup(const std::string& name);
/*!
* \brief Lookup a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const GlobalTypeVar& var);
/*!
* \brief Lookup a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TypeData LookupDef(const std::string& var);
/*!
* \brief Update the functions inside this environment by
* functions in another environment.
* \param other The other environment.
......@@ -137,6 +171,11 @@ class ModuleNode : public RelayNode {
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_var_map_;
/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;
};
struct Module : public NodeRef {
......
......@@ -422,7 +422,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeVarNode::make(name, TypeVarNode::Kind::kType);
auto param = TypeVarNode::make(name, Kind::kType);
type_params.push_back(param);
arg_types.push_back(param);
}
......@@ -430,7 +430,7 @@ inline OpRegistry& OpRegistry::add_type_rel(
Array<Type> ty_call_args = arg_types;
// Add output type.
auto out_param = TypeVarNode::make("out", TypeVarNode::Kind::kType);
auto out_param = TypeVarNode::make("out", Kind::kType);
type_params.push_back(out_param);
// this will trigger copy on write.
ty_call_args.push_back(out_param);
......
......@@ -56,9 +56,9 @@ TVM_DLL Function InferType(const Function& f, const Module& mod,
* \param t The type to check.
* \param mod The global module.
*
* \return true if the rules are satisified otherwise false
* \return The kind of the passed type.
*/
TVM_DLL bool KindCheck(const Type& t, const Module& mod);
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
/*! \brief Compare two expressions for structural equivalence.
*
......@@ -144,10 +144,11 @@ TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);
* type in the context.
*
* \param expr the expression.
* \param mod the module.
*
* \return List of free vars, in the PostDFS order visited by expr.
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get free TypeVars from type t.
*
......@@ -155,10 +156,11 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
* type in the context.
*
* \param t the type.
* \param mod the module.
*
* \return List of free type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t);
TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t, const Module& mod);
/*! \brief Get all bound type variables from expression expr.
*
......@@ -166,10 +168,11 @@ TVM_DLL tvm::Array<TypeVar> FreeTypeVars(const Type& t);
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
* \param mod the module.
*
* \return List of bound type vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get all bound type variables from type t.
*
......@@ -177,26 +180,29 @@ TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Expr& expr);
* They only have meaning inside that type, and can only be used in it.
*
* \param t the type
* \param mod the module.
*
* \return List of bound type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t);
TVM_DLL tvm::Array<TypeVar> BoundTypeVars(const Type& t, const Module& mod);
/*! \brief Get all type variables in expression expr.
*
* \param expr the expression.
* \param mod the module.
*
* \return List of type vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod);
/*! \brief Get all type variables in type t.
*
* \param t the type.
* \param mod the module.
*
* \return List of type vars, in the PostDFS order visited by type.
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const Module& mod);
/*! \brief Remove expressions which does not effect the program result.
*
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/pattern_functor.h
* \brief A more powerful visitor on ADT patterns that enables defining
* arbitrary function signatures with type-based dispatch on first argument.
*/
#ifndef TVM_RELAY_PATTERN_FUNCTOR_H_
#define TVM_RELAY_PATTERN_FUNCTOR_H_
#include <tvm/node/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./op.h"
#include "./error.h"
#include "./adt.h"
namespace tvm {
namespace relay {
/*!
* \brief A dynamical functor on ADT patterns that dispatches on its first argument.
* You can use this as a more powerful visitor, since it allows you to
* define the types of further arguments to VisitPattern.
*
* \sa tvm/ir_functor.h
*
* \tparam FType function signiture
* This type is only defined for FType with function signature R(const Pattern&,
* Args...)
*/
template <typename FType>
class PatternFunctor;
// functions to be overriden.
#define PATTERN_FUNCTOR_DEFAULT \
{ return VisitPatternDefault_(op, std::forward<Args>(args)...); }
#define RELAY_PATTERN_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitPattern_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
});
template <typename R, typename... Args>
class PatternFunctor<R(const Pattern& n, Args...)> {
private:
using TSelf = PatternFunctor<R(const Pattern& n, Args...)>;
using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
public:
/*! \brief the result type of this functor */
using result_type = R;
/*! \brief virtual destructor */
virtual ~PatternFunctor() {}
/*!
* \brief Same as call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
R operator()(const Pattern& n, Args... args) {
return VisitPattern(n, std::forward<Args>(args)...);
}
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitPattern(const Pattern& n, Args... args) {
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
// Functions that can be overriden by subclass
virtual R VisitPattern_(const PatternWildcardNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternVarNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternConstructorNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
private:
// initialize the vtable.
static FType InitVTable() {
FType vtable;
// Set dispatch
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
return vtable;
}
};
/*! \brief A simple visitor wrapper around PatternFunctor.
*
* Exposes two visitors with default traversal strategies, one
* which doesn't compute a result but can mutate internal state,
* and another which functionally builds a new pattern.
*/
class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n)> {
public:
void VisitPattern_(const PatternWildcardNode* op) override;
void VisitPattern_(const PatternVarNode* op) override;
void VisitPattern_(const PatternConstructorNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitVar(const Var& v);
virtual void VisitConstructor(const Constructor& c);
};
/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
class PatternMutator
: public ::tvm::relay::PatternFunctor<Pattern(const Pattern&)> {
public:
Pattern Mutate(const Pattern& pat);
Pattern VisitPattern_(const PatternWildcardNode* op) override;
Pattern VisitPattern_(const PatternVarNode* op) override;
Pattern VisitPattern_(const PatternConstructorNode* op) override;
/*! \brief Used to visit the types inside of patterns.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
/*! \brief Used to visit the vars inside of patterns. */
virtual Var VisitVar(const Var& v);
/*! \brief Used to visit the vars inside of patterns. */
virtual Constructor VisitConstructor(const Constructor& c);
private:
std::unordered_map<Var, Var, NodeHash, NodeEqual> var_map_;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PATTERN_FUNCTOR_H_
......@@ -98,6 +98,18 @@ class TensorTypeNode : public BaseTensorTypeNode {
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
/*! \brief possible kinds of Type */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
kShapeVar = 1,
kBaseType = 2,
kShape = 3,
kConstraint = 4,
kAdtHandle = 5,
kTypeData = 6
};
/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
......@@ -119,14 +131,6 @@ class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*! \brief possible kinds of TypeVar */
enum Kind : int {
/*! \brief template variable in shape expression */
kType = 0,
kShapeVar = 1,
kBaseType = 2,
kShape = 3
};
/*!
* \brief The variable itself is only meaningful when
* kind is ShapeVar, otherwise, we only use the name.
......@@ -150,6 +154,63 @@ class TypeVarNode : public TypeNode {
RELAY_DEFINE_NODE_REF(TypeVar, TypeVarNode, Type);
/*!
* \brief A global type variable that is used for defining new types or type aliases.
*/
class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode {
public:
/*!
* \brief The variable itself is only meaningful when
* kind is ShapeVar; otherwise, we only use the name.
*/
tvm::Var var;
/*! \brief The kind of type parameter */
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static GlobalTypeVar make(std::string name, Kind kind);
static constexpr const char* _type_key = "relay.GlobalTypeVar";
TVM_DECLARE_NODE_TYPE_INFO(GlobalTypeVarNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(GlobalTypeVar, GlobalTypeVarNode, Type);
/*!
* \brief Type application.
*/
class TypeCall;
/*! \brief TypeCall container node */
class TypeCallNode : public TypeNode {
public:
/*!
* \brief The type-level function (ADT that takes type params).
*/
Type func;
/*! \brief The arguments. */
tvm::Array<Type> args;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("span", &span);
}
TVM_DLL static TypeCall make(Type func, tvm::Array<Type> args);
static constexpr const char* _type_key = "relay.TypeCall";
TVM_DECLARE_NODE_TYPE_INFO(TypeCallNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(TypeCall, TypeCallNode, Type);
/*!
* \brief IncompleteType.
* This is intermediate values that is used during type inference.
*
......@@ -162,14 +223,14 @@ class IncompleteType;
/*! \brief IncompleteType container node */
class IncompleteTypeNode : public TypeNode {
public:
TypeVarNode::Kind kind;
Kind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("kind", &kind);
v->Visit("span", &span);
}
TVM_DLL static IncompleteType make(TypeVarNode::Kind kind);
TVM_DLL static IncompleteType make(Kind kind);
static constexpr const char* _type_key = "relay.IncompleteType";
TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
......
......@@ -7,8 +7,10 @@ from . import ty
from . import expr
from . import expr_functor
from . import module
from . import adt
from . import ir_pass
from .build_module import build, build_config, create_executor, optimize
from . import prelude
from . import parser
from . import debug
......@@ -45,6 +47,8 @@ TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
RefType = ty.RefType
GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall
# Expr
Expr = expr.Expr
......@@ -61,6 +65,15 @@ RefCreate = expr.RefCreate
RefRead = expr.RefRead
RefWrite = expr.RefWrite
# ADT
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
Constructor = adt.Constructor
TypeData = adt.TypeData
Clause = adt.Clause
Match = adt.Match
# helper functions
var = expr.var
const = expr.const
......
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""Algebraic data types in Relay."""
from .base import RelayNode, register_relay_node, NodeBase
from . import _make
from .ty import Type
from .expr import Expr, Call
class Pattern(RelayNode):
"""Base type for pattern matching constructs."""
pass
@register_relay_node
class PatternWildcard(Pattern):
"""Wildcard pattern in Relay: Matches any ADT and binds nothing."""
def __init__(self):
"""Constructs a wildcard pattern.
Parameters
----------
None
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
self.__init_handle_by_constructor__(_make.PatternWildcard)
@register_relay_node
class PatternVar(Pattern):
"""Variable pattern in Relay: Matches anything and binds it to the variable."""
def __init__(self, var):
"""Construct a variable pattern.
Parameters
----------
var: tvm.relay.Var
Returns
-------
pv: PatternVar
A variable pattern.
"""
self.__init_handle_by_constructor__(_make.PatternVar, var)
@register_relay_node
class PatternConstructor(Pattern):
"""Constructor pattern in Relay: Matches an ADT of the given constructor, binds recursively."""
def __init__(self, constructor, patterns=None):
"""Construct a constructor pattern.
Parameters
----------
constructor: Constructor
The constructor.
patterns: Optional[List[Pattern]]
Optional subpatterns: for each field of the constructor,
match to the given subpattern (treated as a variable pattern by default).
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns)
@register_relay_node
class Constructor(Expr):
"""Relay ADT constructor."""
def __init__(self, name_hint, inputs, belong_to):
"""Defines an ADT constructor.
Parameters
----------
name_hint : str
Name of constructor (only a hint).
inputs : List[Type]
Input types.
belong_to : tvm.relay.GlobalTypeVar
Denotes which ADT the constructor belongs to.
Returns
-------
con: Constructor
A constructor.
"""
self.__init_handle_by_constructor__(_make.Constructor, name_hint, inputs, belong_to)
def __call__(self, *args):
"""Call the constructor.
Parameters
----------
args: List[relay.Expr]
The arguments to the constructor.
Returns
-------
call: relay.Call
A call to the constructor.
"""
return Call(self, args)
@register_relay_node
class TypeData(Type):
"""Stores the definition for an Algebraic Data Type (ADT) in Relay.
Note that ADT definitions are treated as type-level functions because
the type parameters need to be given for an instance of the ADT. Thus,
any global type var that is an ADT header needs to be wrapped in a
type call that passes in the type params.
"""
def __init__(self, header, type_vars, constructors):
"""Defines a TypeData object.
Parameters
----------
header: tvm.relay.GlobalTypeVar
The name of the ADT.
ADTs with the same constructors but different names are
treated as different types.
type_vars: List[TypeVar]
Type variables that appear in constructors.
constructors: List[tvm.relay.Constructor]
The constructors for the ADT.
Returns
-------
type_data: TypeData
The adt declaration.
"""
self.__init_handle_by_constructor__(_make.TypeData, header, type_vars, constructors)
@register_relay_node
class Clause(NodeBase):
"""Clause for pattern matching in Relay."""
def __init__(self, lhs, rhs):
"""Construct a clause.
Parameters
----------
lhs: tvm.relay.Pattern
Left-hand side of match clause.
rhs: tvm.relay.Expr
Right-hand side of match clause.
Returns
-------
clause: Clause
The Clause.
"""
self.__init_handle_by_constructor__(_make.Clause, lhs, rhs)
@register_relay_node
class Match(Expr):
"""Pattern matching expression in Relay."""
def __init__(self, data, clauses):
"""Construct a Match.
Parameters
----------
data: tvm.relay.Expr
The value being deconstructed and matched.
clauses: List[tvm.relay.Clause]
The pattern match clauses.
Returns
-------
match: tvm.relay.Expr
The match expression.
"""
self.__init_handle_by_constructor__(_make.Match, data, clauses)
......@@ -292,6 +292,12 @@ class GraphRuntimeCodegen(ExprFunctor):
def visit_ref_write(self, _):
raise RuntimeError("reference not supported")
def visit_constructor(self, _):
raise Exception("ADT constructor case not yet implemented")
def visit_match(self, _):
raise Exception("match case not yet implemented")
def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
......
......@@ -53,6 +53,13 @@ class Closure(Value):
@register_relay_node
class ConstructorValue(Value):
def __init__(self, constructor, fields, types):
self.__init_handle_by_constructor__(
_make.ConstructorValue, constructor, fields, types)
@register_relay_node
class TensorValue(Value):
"""A Tensor value produced by the interpreter."""
......
......@@ -172,6 +172,20 @@ class Var(Expr):
name = self.vid.name_hint
return name
def __call__(self, *args):
"""Call the variable (if it represents a function).
Parameters
----------
args: List[relay.Expr]
The arguments to the call.
Returns
-------
call: Call
A call taking the variable as a function.
"""
return Call(self, args)
@register_relay_node
class GlobalVar(Expr):
......
......@@ -2,6 +2,7 @@
"""The expression functor of Relay."""
from .expr import Function, Call, Let, Var, GlobalVar, If, Tuple, TupleGetItem, Constant
from .adt import Constructor, Match, Clause
from .op import Op
class ExprFunctor:
......@@ -47,6 +48,10 @@ class ExprFunctor:
res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite):
res = self.visit_ref_write(expr)
elif isinstance(expr, Constructor):
res = self.visit_constructor(expr)
elif isinstance(expr, Match):
res = self.visit_match(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
......@@ -96,6 +101,13 @@ class ExprFunctor:
def visit_ref_read(self, _):
raise NotImplementedError()
def visit_constructor(self, _):
raise NotImplementedError()
def visit_match(self, _):
raise NotImplementedError()
class ExprMutator(ExprFunctor):
"""
A functional visitor over Expr.
......
......@@ -9,6 +9,7 @@ from . import _ir_pass
from . import _make
from .expr import Expr
from .ty import Type
from .module import Module
def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
......@@ -107,7 +108,7 @@ def well_formed(expr):
def check_kind(t, mod=None):
"""Check that the type is well kinded.
"""Check that the type is well kinded and return the kind.
For example, this mean type cannot has tensor of tensor, or is a tuple type of 2 shapes.
Parameters
......@@ -120,15 +121,15 @@ def check_kind(t, mod=None):
Returns
-------
well_kinded : bool
whether the input type is well kinded.
kind : Kind
the kind of t
Examples
--------
.. code:: python
assert not check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)]))
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)]))
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) == Shape
assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type
"""
if mod is not None:
return _ir_pass.check_kind(t, mod)
......@@ -190,52 +191,61 @@ def all_vars(expr):
return _ir_pass.all_vars(expr)
def free_type_vars(expr):
def free_type_vars(expr, mod=None):
"""Get free type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
The global module
Returns
-------
free : List[tvm.relay.TypeVar]
The list of free type variables in post-DFS order
"""
return _ir_pass.free_type_vars(expr)
use_mod = mod if mod is not None else Module()
return _ir_pass.free_type_vars(expr, use_mod)
def bound_type_vars(expr):
def bound_type_vars(expr, mod=None):
"""Get bound type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
The global module
Returns
-------
free : List[tvm.relay.TypeVar]
The list of bound type variables in post-DFS order
"""
return _ir_pass.bound_type_vars(expr)
use_mod = mod if mod is not None else Module()
return _ir_pass.bound_type_vars(expr, use_mod)
def all_type_vars(expr):
def all_type_vars(expr, mod=None):
"""Get all type variables from expression/type e
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
The global module
Returns
-------
free : List[tvm.relay.TypeVar]
The list of all type variables in post-DFS order
"""
return _ir_pass.all_type_vars(expr)
use_mod = mod if mod is not None else Module()
return _ir_pass.all_type_vars(expr, use_mod)
def simplify_inference(expr):
......
......@@ -6,6 +6,7 @@ from . import _make
from . import _module
from . import expr as _expr
from . import ty as _ty
@register_relay_node
class Module(RelayNode):
......@@ -20,7 +21,7 @@ class Module(RelayNode):
functions : dict, optional.
Map of global var to Function
"""
def __init__(self, functions=None):
def __init__(self, functions=None, type_definitions=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
......@@ -32,28 +33,46 @@ class Module(RelayNode):
raise TypeError("Expect functions to be Dict[GlobalVar, Function]")
mapped_funcs[k] = v
functions = mapped_funcs
self.__init_handle_by_constructor__(_make.Module, functions)
if type_definitions is None:
type_definitions = {}
elif isinstance(type_definitions, dict):
mapped_type_defs = {}
for k, v in type_definitions.items():
if isinstance(k, _base.string_types):
k = _ty.GlobalTypeVar(k)
if not isinstance(k, _ty.GlobalTypeVar):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_make.Module, functions, type_definitions)
def __setitem__(self, var, func):
"""Add a function to the module.
def __setitem__(self, var, val):
"""Add a mapping to the module.
Parameters
---------
var: GlobalVar
The global variable which names the function.
The global variable.
func: Function
The function.
val: Union[Function, Type]
The value.
"""
return self._add(var, func)
return self._add(var, val)
def _add(self, var, func, update=False):
if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var)
return _module.Module_Add(self, var, func, update)
def _add(self, var, val, update=False):
if isinstance(val, _expr.Function):
if isinstance(var, _base.string_types):
var = _expr.GlobalVar(var)
_make.Module_Add(self, var, val, update)
else:
assert isinstance(val, _ty.Type)
if isinstance(var, _base.string_types):
var = _ty.GlobalTypeVar(var)
_module.Module_AddDef(self, var, val)
def __getitem__(self, var):
"""Lookup a global function by name or by variable.
"""Lookup a global definition by name or by variable.
Parameters
----------
......@@ -62,13 +81,15 @@ class Module(RelayNode):
Returns
-------
func: Function
The function referenced by :code:`var`.
val: Union[Function, Type]
The definition referenced by :code:`var` (either a function or type).
"""
if isinstance(var, _base.string_types):
return _module.Module_Lookup_str(self, var)
else:
elif isinstance(var, _expr.GlobalVar):
return _module.Module_Lookup(self, var)
else:
return _module.Module_LookupDef(self, var)
def update(self, other):
"""Insert functions in another Module to current one.
......@@ -100,3 +121,22 @@ class Module(RelayNode):
tvm.TVMError if we cannot find corresponding global var.
"""
return _module.Module_GetGlobalVar(self, name)
def get_global_type_var(self, name):
"""Get a global type variable in the function by name.
Parameters
----------
name: str
The name of the global type variable.
Returns
-------
global_type_var: GlobalTypeVar
The global variable mapped to :code:`name`.
Raises
------
tvm.TVMError if we cannot find corresponding global type var.
"""
return _module.Module_GetGlobalTypeVar(self, name)
......@@ -21,6 +21,19 @@ class Type(RelayNode):
"""Compares two Relay types by referential equality."""
return super().__eq__(other)
def __call__(self, *args):
"""Create a type call from this type.
Parameters
----------
args: List[relay.Type]
The arguments to the type call.
Returns
-------
call: relay.TypeCall
"""
return TypeCall(self, args)
@register_relay_node
class TensorType(Type):
......@@ -75,6 +88,9 @@ class Kind(IntEnum):
ShapeVar = 1
BaseType = 2
Shape = 3
Constraint = 4
AdtHandle = 5
TypeData = 6
@register_relay_node
class TypeVar(Type):
......@@ -107,6 +123,53 @@ class TypeVar(Type):
@register_relay_node
class GlobalTypeVar(Type):
"""A global type variable in Relay.
GlobalTypeVar is used to refer to the global type-level definitions
stored in the environment.
"""
def __init__(self, var, kind=Kind.AdtHandle):
"""Construct a GlobalTypeVar.
Parameters
----------
var: tvm.Var
The tvm.Var which backs the type parameter.
kind: Kind, optional
The kind of the type parameter, Kind.AdtHandle by default.
Returns
-------
type_var: GlobalTypeVar
The global type variable.
"""
self.__init_handle_by_constructor__(_make.GlobalTypeVar, var, kind)
@register_relay_node
class TypeCall(Type):
"""Type-level function application in Relay.
A type call applies argument types to a constructor (type-level function).
"""
def __init__(self, func, args):
"""Construct a TypeCall.
Parameters
----------
func: tvm.relay.Type
The function.
args: List[tvm.expr.Type]
The arguments.
Returns
-------
type_call: TypeCall
The type function application.
"""
self.__init_handle_by_constructor__(_make.TypeCall, func, args)
@register_relay_node
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
pass
......
......@@ -6,6 +6,7 @@
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/attrs/debug.h>
......@@ -92,6 +93,26 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefValueNode(" << node->value << ")";
});
ConstructorValue ConstructorValueNode::make(Constructor constructor,
tvm::Array<Value> fields) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
n->constructor = constructor;
n->fields = fields;
return ConstructorValue(n);
}
TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstructorValueNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
tvm::IRPrinter* p) {
p->stream << "ConstructorValueNode(" << node->constructor
<< node->fields << ")";
});
/*!
* \brief A stack frame in the Relay interpreter.
*
......@@ -185,7 +206,8 @@ InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
//
// Conversion to ANF is recommended before running the interpretation.
class Interpreter :
public ExprFunctor<Value(const Expr& n)> {
public ExprFunctor<Value(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const Value& v)> {
public:
Interpreter(Module mod,
DLContext context,
......@@ -209,7 +231,7 @@ class Interpreter :
}
Value Eval(const Expr& expr) {
return (*this)(expr);
return VisitExpr(expr);
}
Value VisitExpr(const Expr& expr) final {
......@@ -401,6 +423,9 @@ class Interpreter :
<< "; operators should be removed by future passes; try "
"fusing and lowering";
}
if (auto con = call->op.as<ConstructorNode>()) {
return ConstructorValueNode::make(GetRef<Constructor>(con), args);
}
// Now we just evaluate and expect to find a closure.
Value fn_val = Eval(call->op);
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
......@@ -474,6 +499,44 @@ class Interpreter :
}
}
Value VisitExpr_(const MatchNode* op) final {
Value v = Eval(op->data);
for (const Clause& c : op->clauses) {
if (VisitPattern(c->lhs, v)) {
return VisitExpr(c->rhs);
}
}
LOG(FATAL) << "did not find any match";
return Value();
}
bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final {
const ConstructorValueNode* cvn = v.as<ConstructorValueNode>();
CHECK(cvn) << "need to be a constructor for match";
CHECK_NE(op->constructor->tag, -1);
CHECK_NE(cvn->constructor->tag, -1);
if (op->constructor->tag == cvn->constructor->tag) {
// todo(M.K.): should use ptr equality but it is broken
CHECK(op->patterns.size() == cvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], cvn->fields[i])) {
return false;
}
}
return true;
}
return false;
}
bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
return true;
}
bool VisitPattern_(const PatternVarNode* op, const Value& v) final {
extend(op->var, v);
return true;
}
InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) {
......@@ -485,14 +548,14 @@ class Interpreter :
}
private:
// module
// Module
Module mod_;
// For simplicity we only run the interpreter on a single context.
// Context to run the interpreter on.
DLContext context_;
// Target parameter being used by the interpreter.
Target target_;
// value stack.
// Value stack.
Stack stack_;
// Backend compile engine.
CompileEngine engine_;
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/ir/adt.cc
* \brief AST nodes for Relay algebraic data types (ADTs).
*/
#include <tvm/relay/type.h>
#include <tvm/relay/adt.h>
namespace tvm {
namespace relay {
PatternWildcard PatternWildcardNode::make() {
NodePtr<PatternWildcardNode> n = make_node<PatternWildcardNode>();
return PatternWildcard(n);
}
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_API("relay._make.PatternWildcard")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternWildcardNode::make();
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternWildcardNode()";
});
PatternVar PatternVarNode::make(tvm::relay::Var var) {
NodePtr<PatternVarNode> n = make_node<PatternVarNode>();
n->var = std::move(var);
return PatternVar(n);
}
TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_API("relay._make.PatternVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternVarNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const PatternVarNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternVarNode(" << node->var << ")";
});
PatternConstructor PatternConstructorNode::make(Constructor constructor,
tvm::Array<Pattern> patterns) {
NodePtr<PatternConstructorNode> n = make_node<PatternConstructorNode>();
n->constructor = std::move(constructor);
n->patterns = std::move(patterns);
return PatternConstructor(n);
}
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_API("relay._make.PatternConstructor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternConstructorNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternConstructorNode(" << node->constructor
<< ", " << node->patterns << ")";
});
Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
NodePtr<ConstructorNode> n = make_node<ConstructorNode>();
n->name_hint = std::move(name_hint);
n->inputs = std::move(inputs);
n->belong_to = std::move(belong_to);
return Constructor(n);
}
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_API("relay._make.Constructor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstructorNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ConstructorNode* node,
tvm::IRPrinter* p) {
p->stream << "ConstructorNode(" << node->name_hint << ", "
<< node->inputs << ", " << node->belong_to << ")";
});
TypeData TypeDataNode::make(GlobalTypeVar header,
tvm::Array<TypeVar> type_vars,
tvm::Array<Constructor> constructors) {
NodePtr<TypeDataNode> n = make_node<TypeDataNode>();
n->header = std::move(header);
n->type_vars = std::move(type_vars);
n->constructors = std::move(constructors);
return TypeData(n);
}
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_API("relay._make.TypeData")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeDataNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const TypeDataNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeDataNode(" << node->header << ", " << node->type_vars << ", "
<< node->constructors << ")";
});
Clause ClauseNode::make(Pattern lhs, Expr rhs) {
NodePtr<ClauseNode> n = make_node<ClauseNode>();
n->lhs = std::move(lhs);
n->rhs = std::move(rhs);
return Clause(n);
}
TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_API("relay._make.Clause")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ClauseNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ClauseNode* node,
tvm::IRPrinter* p) {
p->stream << "ClauseNode(" << node->lhs << ", "
<< node->rhs << ")";
});
Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) {
NodePtr<MatchNode> n = make_node<MatchNode>();
n->data = std::move(data);
n->clauses = std::move(clauses);
return Match(n);
}
TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_API("relay._make.Match")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = MatchNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node,
tvm::IRPrinter* p) {
p->stream << "MatchNode(" << node->data << ", "
<< node->clauses << ")";
});
} // namespace relay
} // namespace tvm
......@@ -5,6 +5,7 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h>
#include "type_functor.h"
......@@ -17,7 +18,8 @@ namespace relay {
class AlphaEqualHandler:
public AttrsEqualHandler,
public TypeFunctor<bool(const Type&, const Type&)>,
public ExprFunctor<bool(const Expr&, const Expr&)> {
public ExprFunctor<bool(const Expr&, const Expr&)>,
public PatternFunctor<bool(const Pattern&, const Pattern&)> {
public:
explicit AlphaEqualHandler(bool map_free_var)
: map_free_var_(map_free_var) {}
......@@ -160,7 +162,7 @@ class AlphaEqualHandler:
}
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
// set up type parameter equal
if (lhs->type_params[i]->kind == TypeVarNode::Kind::kShapeVar) {
if (lhs->type_params[i]->kind == Kind::kShapeVar) {
// map variable
equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
}
......@@ -215,6 +217,26 @@ class AlphaEqualHandler:
return false;
}
bool VisitType_(const GlobalTypeVarNode* op, const Type& t2) final {
return GetRef<Type>(op) == t2;
}
bool VisitType_(const TypeCallNode* op, const Type& t2) final {
const TypeCallNode* pt = t2.as<TypeCallNode>();
if (pt == nullptr
|| op->args.size() != pt->args.size()
|| !TypeEqual(op->func, pt->func)) {
return false;
}
for (size_t i = 0; i < op->args.size(); ++i) {
if (!TypeEqual(op->args[i], pt->args[i])) {
return false;
}
}
return true;
}
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
......@@ -261,11 +283,9 @@ class AlphaEqualHandler:
bool VisitExpr_(const GlobalVarNode* lhs, const Expr& other) final {
if (const GlobalVarNode* rhs = other.as<GlobalVarNode>()) {
// use name equality for global var for now.
if (lhs->name_hint != rhs->name_hint) return false;
return true;
} else {
return false;
return lhs->name_hint == rhs->name_hint;
}
return false;
}
bool VisitExpr_(const TupleNode* lhs, const Expr& other) final {
......@@ -392,6 +412,63 @@ class AlphaEqualHandler:
return false;
}
}
bool VisitExpr_(const ConstructorNode* op, const Expr& e2) final {
return GetRef<Expr>(op) == e2;
}
bool ClauseEqual(const Clause& l, const Clause& r) {
return PatternEqual(l->lhs, r->lhs) && ExprEqual(l->rhs, r->rhs);
}
bool PatternEqual(const Pattern& l, const Pattern& r) {
return VisitPattern(l, r);
}
bool VisitPattern_(const PatternWildcardNode* op, const Pattern& r) final {
return r.as<PatternWildcardNode>();
}
bool VisitPattern_(const PatternVarNode* op, const Pattern& e2) final {
if (const auto* r = e2.as<PatternVarNode>()) {
return MergeVarDecl(op->var, r->var);
}
return false;
}
bool VisitPattern_(const PatternConstructorNode* op, const Pattern& e2) final {
const auto* r = e2.as<PatternConstructorNode>();
if (r == nullptr
|| !ExprEqual(op->constructor, r->constructor)
|| op->patterns.size() != r->patterns.size()) {
return false;
}
for (size_t i = 0; i < op->patterns.size(); i++) {
if (!PatternEqual(op->patterns[i], r->patterns[i])) {
return false;
}
}
return true;
}
bool VisitExpr_(const MatchNode* op, const Expr& e2) final {
const MatchNode* r = e2.as<MatchNode>();
if (r == nullptr
|| !ExprEqual(op->data, r->data)
|| op->clauses.size() != r->clauses.size()) {
return false;
}
for (size_t i = 0; i < op->clauses.size(); ++i) {
if (!ClauseEqual(op->clauses[i], r->clauses[i])) {
return false;
}
}
return true;
}
private:
// whether to map open terms.
bool map_free_var_{false};
......
......@@ -130,9 +130,14 @@ Function FunctionNode::make(tvm::Array<Var> params,
FuncType FunctionNode::func_type_annotation() const {
Array<Type> param_types;
for (auto param : this->params) {
param_types.push_back(param->type_annotation);
Type param_type = (param->type_annotation.defined()) ? param->type_annotation
: IncompleteTypeNode::make(Kind::kType);
param_types.push_back(param_type);
}
return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
Type ret_type = (this->ret_type.defined()) ? this->ret_type
: IncompleteTypeNode::make(Kind::kType);
return FuncTypeNode::make(param_types, ret_type, this->type_params, {});
}
bool FunctionNode::IsPrimitive() const {
......
......@@ -185,6 +185,24 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
}
}
Expr ExprMutator::VisitExpr_(const ConstructorNode* c) {
return GetRef<Expr>(c);
}
Expr ExprMutator::VisitExpr_(const MatchNode* m) {
std::vector<Clause> clauses;
for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p));
}
return MatchNode::make(VisitExpr(m->data), clauses);
}
Clause ExprMutator::VisitClause(const Clause& c) {
return ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs));
}
Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::VisitExpr(const Expr& expr) {
......@@ -267,6 +285,27 @@ void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
this->VisitExpr(op->value);
}
void ExprVisitor::VisitExpr_(const ConstructorNode* op) {
for (const Type& t : op->inputs) {
this->VisitType(t);
}
this->VisitType(op->belong_to);
}
void ExprVisitor::VisitExpr_(const MatchNode* op) {
this->VisitExpr(op->data);
for (const Clause& c : op->clauses) {
this->VisitClause(c);
}
}
void ExprVisitor::VisitClause(const Clause& op) {
this->VisitPattern(op->lhs);
this->VisitExpr(op->rhs);
}
void ExprVisitor::VisitPattern(const Pattern& p) { return; }
void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/ir_pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/pass.h>
#include <tvm/attrs.h>
......@@ -18,7 +19,8 @@ namespace relay {
class RelayHashHandler:
public AttrsHashHandler,
public TypeFunctor<size_t(const Type&)>,
public ExprFunctor<size_t(const Expr&)> {
public ExprFunctor<size_t(const Expr&)>,
public PatternFunctor<size_t(const Pattern&)> {
public:
explicit RelayHashHandler() {}
......@@ -201,7 +203,7 @@ class RelayHashHandler:
hash_map_[var] = hash;
const auto* ty_param = var.as<TypeVarNode>();
if (ty_param && ty_param->kind == TypeVarNode::Kind::kShapeVar) {
if (ty_param && ty_param->kind == Kind::kShapeVar) {
hash_map_[ty_param->var] = hash;
}
return hash;
......@@ -236,7 +238,7 @@ class RelayHashHandler:
}
hash = Combine(hash, TypeHash(func->ret_type));
hash = Combine(hash, ExprHash(func->body));
hash = Combine(hash, ExprHash(func->body));
return hash;
}
......@@ -249,6 +251,10 @@ class RelayHashHandler:
hash = Combine(hash, ExprHash(arg));
}
for (auto t : call->type_args) {
hash = Combine(hash, TypeHash(t));
}
hash = Combine(hash, AttrHash(call->attrs));
return hash;
......@@ -304,6 +310,72 @@ class RelayHashHandler:
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
size_t VisitExpr_(const MatchNode* mn) final {
size_t hash = std::hash<std::string>()(MatchNode::_type_key);
hash = Combine(hash, ExprHash(mn->data));
for (const auto& c : mn->clauses) {
hash = Combine(hash, PatternHash(c->lhs));
hash = Combine(hash, ExprHash(c->rhs));
}
return hash;
}
size_t VisitExpr_(const ConstructorNode* cn) final {
size_t hash = std::hash<std::string>()(ConstructorNode::_type_key);
hash = Combine(hash, std::hash<std::string>()(cn->name_hint));
return hash;
}
size_t VisitType_(const TypeCallNode* tcn) final {
size_t hash = std::hash<std::string>()(TypeCallNode::_type_key);
hash = Combine(hash, TypeHash(tcn->func));
for (const auto& t : tcn->args) {
hash = Combine(hash, TypeHash(t));
}
return hash;
}
size_t VisitType_(const TypeDataNode* tdn) final {
size_t hash = std::hash<std::string>()(TypeDataNode::_type_key);
hash = Combine(hash, TypeHash(tdn->header));
for (const auto& tv : tdn->type_vars) {
hash = Combine(hash, TypeHash(tv));
}
for (const auto& cn : tdn->constructors) {
hash = Combine(hash, ExprHash(cn));
}
return hash;
}
size_t VisitType_(const GlobalTypeVarNode* tvn) final {
return BindVar(GetRef<GlobalTypeVar>(tvn));
}
size_t PatternHash(const Pattern& p) {
return VisitPattern(p);
}
size_t VisitPattern_(const PatternConstructorNode* pcn) final {
size_t hash = std::hash<std::string>()(PatternConstructorNode::_type_key);
hash = Combine(hash, ExprHash(pcn->constructor));
for (const auto& p : pcn->patterns) {
hash = Combine(hash, PatternHash(p));
}
return hash;
}
size_t VisitPattern_(const PatternVarNode* pvn) final {
size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
hash = Combine(hash, BindVar(pvn->var));
return hash;
}
size_t VisitPattern_(const PatternWildcardNode* pwn) final {
size_t hash = std::hash<std::string>()(PatternWildcardNode::_type_key);
return hash;
}
private:
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_;
......
......@@ -13,18 +13,28 @@ namespace relay {
using tvm::IRPrinter;
using namespace runtime;
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs) {
auto n = make_node<ModuleNode>();
n->functions = std::move(global_funcs);
n->type_definitions = std::move(global_type_defs);
for (const auto& kv : n->functions) {
// set gloval var map
// set global var map
CHECK(!n->global_var_map_.count(kv.first->name_hint))
<< "Duplicate global function name " << kv.first->name_hint;
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}
n->entry_func = GlobalVarNode::make("main");
for (const auto& kv : n->type_definitions) {
// set global typevar map
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
<< "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
}
return Module(n);
}
......@@ -51,6 +61,13 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
global_var_map_.Set(var->name_hint, var);
}
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) {
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
return (*it).second;
}
void ModuleNode::Add(const GlobalVar& var,
const Function& func,
bool update) {
......@@ -69,6 +86,22 @@ void ModuleNode::Add(const GlobalVar& var,
AddUnchecked(var, checked_func);
}
void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
this->type_definitions.Set(var, type);
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = i;
}
// need to kind check at the end because the check can look up
// a definition potentially
CHECK(KindCheck(type, GetRef<Module>(this)) == Kind::kTypeData)
<< "Invalid or malformed typedata given to module: " << type;
}
void ModuleNode::Update(const GlobalVar& var, const Function& func) {
this->Add(var, func, true);
}
......@@ -92,6 +125,18 @@ Function ModuleNode::Lookup(const std::string& name) {
return this->Lookup(id);
}
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->var->name_hint;
return (*it).second;
}
TypeData ModuleNode::LookupDef(const std::string& name) {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupDef(id);
}
void ModuleNode::Update(const Module& mod) {
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
......@@ -101,7 +146,7 @@ void ModuleNode::Update(const Module& mod) {
Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) {
auto mod = ModuleNode::make(global_funcs);
auto mod = ModuleNode::make(global_funcs, {});
auto func_node = expr.as<FunctionNode>();
Function func;
if (func_node) {
......@@ -117,21 +162,33 @@ TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ModuleNode::make(args[0]);
*ret = ModuleNode::make(args[0], args[1]);
});
TVM_REGISTER_API("relay._module.Module_Add")
TVM_REGISTER_API("relay._make.Module_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
mod->Add(args[1], args[2], args[3]);
});
TVM_REGISTER_API("relay._module.Module_AddDef")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
mod->AddDef(args[1], args[2]);
});
TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
*ret = mod->GetGlobalVar(args[1]);
});
TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
*ret = mod->GetGlobalTypeVar(args[1]);
});
TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
......@@ -143,8 +200,21 @@ TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
std::string var_name = args[1];
auto var = mod->GetGlobalVar(var_name);
*ret = mod->Lookup(var);
*ret = mod->Lookup(var_name);
});
TVM_REGISTER_API("relay._module.Module_LookupDef")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
GlobalTypeVar var = args[1];
*ret = mod->LookupDef(var);
});
TVM_REGISTER_API("relay._module.Module_LookupDef_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
std::string var_name = args[1];
*ret = mod->LookupDef(var_name);
});
TVM_REGISTER_API("relay._module.Module_Update")
......
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pattern_functor.cc
* \brief Implementations of visitors and mutators for ADT patterns.
*/
#include <tvm/relay/pattern_functor.h>
namespace tvm {
namespace relay {
Pattern PatternMutator::Mutate(const Pattern& pat) {
return (*this)(pat);
}
Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) {
return GetRef<Pattern>(op);
}
Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) {
return PatternVarNode::make(VisitVar(op->var));
}
Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
std::vector<Pattern> pat;
for (const auto& p : op->patterns) {
pat.push_back(VisitPattern(p));
}
return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
}
Type PatternMutator::VisitType(const Type& t) {
return t;
}
Var PatternMutator::VisitVar(const Var& v) {
if (var_map_.count(v) == 0) {
var_map_.insert(std::pair<Var, Var>(v,
VarNode::make(v->name_hint(),
VisitType(v->type_annotation))));
}
return var_map_.at(v);
}
Constructor PatternMutator::VisitConstructor(const Constructor& v) {
return v;
}
void PatternVisitor::VisitPattern_(const PatternWildcardNode* op) { }
void PatternVisitor::VisitPattern_(const PatternVarNode* op) {
VisitVar(op->var);
}
void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
VisitConstructor(op->constructor);
for (const auto& p : op->patterns) {
VisitPattern(p);
}
}
void PatternVisitor::VisitType(const Type& t) { }
void PatternVisitor::VisitVar(const Var& v) {
VisitType(v->type_annotation);
}
void PatternVisitor::VisitConstructor(const Constructor& c) {
for (const auto& inp : c->inputs) {
VisitType(inp);
}
}
} // namespace relay
} // namespace tvm
......@@ -5,6 +5,7 @@
*/
#include <tvm/relay/module.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <sstream>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
......@@ -23,6 +24,12 @@ struct TextValue {
TextValue() {}
// constructor
explicit TextValue(std::string name) : name(name) {}
TextValue operator+(const TextValue& rhs) const {
return TextValue(name + rhs.name);
}
TextValue operator+(const std::string& str) const {
return TextValue(name + str);
}
};
// operator overloading
......@@ -128,6 +135,7 @@ class TextMetaDataContext {
class TextPrinter :
public ExprFunctor<TextValue(const Expr&)>,
public PatternFunctor<TextValue(const Pattern&)>,
public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*)
public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
public:
......@@ -213,6 +221,9 @@ class TextPrinter :
memo_[expr] = val;
return val;
}
TextValue GetValue(const Pattern& p) {
return this->VisitPattern(p);
}
//------------------------------------
// Overload of Expr printing functions
//------------------------------------
......@@ -391,6 +402,36 @@ class TextPrinter :
return id;
}
TextValue VisitExpr_(const MatchNode* op) final {
TextValue data = GetValue(op->data);
this->PrintIndent();
TextValue id = this->AllocTempVar();
stream_ << id << " = " << "Match " << data << " with";
this->PrintEndInst("\n");
for (const auto& c : op->clauses) {
this->PrintIndent();
stream_ << GetValue(c->lhs) << " to " << GetValue(c->rhs);
this->PrintEndInst("\n");
}
return id;
}
TextValue VisitPattern_(const PatternConstructorNode* p) final {
TextValue ret(p->constructor->name_hint + "(");
for (const Pattern& pat : p->patterns) {
ret = ret + " " + GetValue(pat);
}
return ret + ")";
}
TextValue VisitPattern_(const PatternVarNode* pv) final {
return GetValue(pv->var);
}
TextValue VisitExpr_(const ConstructorNode* n) final {
return TextValue(n->name_hint);
}
/*!
* \brief Print the type to os
* \param type The type to be printed.
......@@ -437,6 +478,18 @@ class TextPrinter :
VisitTypeDefault_(node, os);
}
void VisitType_(const TypeCallNode* node, std::ostream& os) final {
os << node->func << "(" << node->args << ")";
}
void VisitType_(const GlobalTypeVarNode* node, std::ostream& os) final {
VisitTypeDefault_(node, os);
}
void VisitType_(const TypeDataNode* node, std::ostream& os) final {
VisitTypeDefault_(node, os);
}
void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*)
// by default always print as meta-data
os << meta_.GetMetaNode(GetRef<NodeRef>(node));
......
......@@ -48,7 +48,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});
TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
TypeVar TypeVarNode::make(std::string name, Kind kind) {
NodePtr<TypeVarNode> n = make_node<TypeVarNode>();
n->var = tvm::Var(name);
n->kind = std::move(kind);
......@@ -61,7 +61,7 @@ TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret =
TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
TypeVarNode::make(args[0], static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......@@ -71,7 +71,50 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->kind << ")";
});
IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) {
GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
NodePtr<GlobalTypeVarNode> n = make_node<GlobalTypeVarNode>();
n->var = tvm::Var(name);
n->kind = std::move(kind);
return GlobalTypeVar(n);
}
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_API("relay._make.GlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret = GlobalTypeVarNode::make(args[0], static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
tvm::IRPrinter *p) {
p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")";
});
TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
NodePtr<TypeCallNode> n = make_node<TypeCallNode>();
n->func = std::move(func);
n->args = std::move(args);
return TypeCall(n);
}
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_API("relay._make.TypeCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeCallNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const TypeCallNode* node,
tvm::IRPrinter* p) {
p->stream << "TypeCallNode(" << node->func << ", "
<< node->args << ")";
});
IncompleteType IncompleteTypeNode::make(Kind kind) {
auto n = make_node<IncompleteTypeNode>();
n->kind = std::move(kind);
return IncompleteType(n);
......@@ -82,7 +125,7 @@ TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
*ret = IncompleteTypeNode::make(static_cast<TypeVarNode::Kind>(kind));
*ret = IncompleteTypeNode::make(static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
......@@ -48,6 +48,29 @@ void TypeVisitor::VisitType_(const TypeRelationNode* op) {
}
}
void TypeVisitor::VisitType_(const GlobalTypeVarNode* op) {
}
void TypeVisitor::VisitType_(const TypeCallNode* op) {
this->VisitType(op->func);
for (const Type& t : op->args) {
this->VisitType(t);
}
}
void TypeVisitor::VisitType_(const TypeDataNode* op) {
this->VisitType(op->header);
for (const auto& v : op->type_vars) {
this->VisitType(v);
}
for (const auto& c : op->constructors) {
this->VisitType(c->belong_to);
for (const auto& t : c->inputs) {
this->VisitType(t);
}
}
}
// Type Mutator.
Array<Type> TypeMutator::MutateArray(Array<Type> arr) {
......@@ -139,6 +162,24 @@ Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
}
}
Type TypeMutator::VisitType_(const GlobalTypeVarNode* op) {
return GetRef<Type>(op);
}
Type TypeMutator::VisitType_(const TypeCallNode* op) {
Type new_func = VisitType(op->func);
Array<Type> new_args = MutateArray(op->args);
if (new_args.same_as(op->args) && new_func.same_as(op->func)) {
return GetRef<TypeCall>(op);
} else {
return TypeCallNode::make(new_func, new_args);
}
}
Type TypeMutator::VisitType_(const TypeDataNode* op) {
return GetRef<Type>(op);
}
// Implements bind.
class TypeBinder : public TypeMutator {
public:
......
......@@ -8,6 +8,7 @@
#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/adt.h>
#include <string>
#include <vector>
......@@ -69,6 +70,10 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const GlobalTypeVarNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
throw; // unreachable, written to stop compiler warning
......@@ -87,6 +92,9 @@ class TypeFunctor<R(const Type& n, Args...)> {
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(GlobalTypeVarNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
return vtable;
}
};
......@@ -103,6 +111,9 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
void VisitType_(const RefTypeNode* op) override;
void VisitType_(const GlobalTypeVarNode* op) override;
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
};
// Mutator that transform a type to another one.
......@@ -115,6 +126,9 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const RefTypeNode* op) override;
Type VisitType_(const GlobalTypeVarNode* op) override;
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
private:
Array<Type> MutateArray(Array<Type> arr);
......
......@@ -296,6 +296,15 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const MatchNode* op) final {
this->Update(op->data, nullptr, kOpaque);
for (const Clause& c : op->clauses) {
this->Update(c->rhs, nullptr, kOpaque);
}
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
};
IndexedForwardGraph IndexedForwardGraph::Create(
......
......@@ -14,106 +14,160 @@
* contains a data type such as `int`, `float`, `uint`.
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/error.h>
#include "../ir/type_functor.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
using Kind = TypeVarNode::Kind;
struct KindChecker : TypeVisitor {
bool valid;
struct KindChecker : TypeFunctor<Kind(const Type&)> {
const Module& mod;
ErrorReporter err_reporter;
KindChecker() : valid(true) {}
explicit KindChecker(const Module& mod) : mod(mod), err_reporter() {}
// checks if t is an incomplete node of kind k or a type param of kind k
bool MatchKind(const Type& t, Kind k) {
if (const IncompleteTypeNode* tv = t.as<IncompleteTypeNode>()) {
return tv->kind == k;
}
void ReportFatalError(const Error& err) {
this->err_reporter.Report(err);
this->err_reporter.RenderErrors(mod);
}
if (const TypeVarNode* tp = t.as<TypeVarNode>()) {
return tp->kind == k;
void CheckKindMatches(const Type& t, const Type& outer,
Kind expected, const std::string& description) {
Kind k = this->VisitType(t);
if (k != expected) {
ReportFatalError(RELAY_ERROR("Incorrect kind for a " << description
<< ". Type " << t << " inside " << outer
<< " is of kind " << k
<< " but was expected to be "
<< expected));
}
}
return false;
Kind VisitType_(const IncompleteTypeNode* op) override {
return op->kind;
}
bool IsTypeKind(const Type& t) {
if (MatchKind(t, Kind::kType)) {
return true;
}
Kind VisitType_(const TypeVarNode* op) override {
return op->kind;
}
Kind VisitType_(const GlobalTypeVarNode* op) override {
return op->kind;
}
return t.as_derived<BaseTensorTypeNode>() || t.as<TupleTypeNode>() || t.as<FuncTypeNode>();
Kind VisitType_(const TensorTypeNode* op) override {
return Kind::kType;
}
void VisitType_(const TupleTypeNode* op) override {
Kind VisitType_(const TupleTypeNode* op) override {
// tuples should only contain normal types
for (const Type& t : op->fields) {
this->VisitType(t);
valid = valid && IsTypeKind(t);
if (!valid) {
return;
}
CheckKindMatches(t, GetRef<TupleType>(op), Kind::kType,
"tuple member");
}
return Kind::kType;
}
void VisitType_(const FuncTypeNode* op) override {
Kind VisitType_(const FuncTypeNode* op) override {
// Func types should only take normal types for arguments
// and only return a normal type. They should also have
// well-formed constraints
FuncType ft = GetRef<FuncType>(op);
for (const Type& t : op->arg_types) {
this->VisitType(t);
valid = valid && IsTypeKind(t);
if (!valid) {
return;
}
CheckKindMatches(t, ft, Kind::kType, "function type parameter");
}
CheckKindMatches(ft->ret_type, ft, Kind::kType, "function return type");
for (const TypeConstraint& tc : op->type_constraints) {
this->VisitType(tc);
if (!valid) {
return;
}
CheckKindMatches(tc, ft, Kind::kConstraint, "function type constraint");
}
this->VisitType(op->ret_type);
valid = valid && IsTypeKind(op->ret_type);
return Kind::kType;
}
void VisitType_(const RefTypeNode* op) override {
// tuples should only contain normal types
this->VisitType(op->value);
valid = valid && IsTypeKind(op->value);
Kind VisitType_(const RefTypeNode* op) override {
// ref types should only contain normal types
RefType rt = GetRef<RefType>(op);
CheckKindMatches(op->value, rt, Kind::kType, "ref contents");
return Kind::kType;
}
void VisitType_(const TypeRelationNode* op) override {
Kind VisitType_(const TypeRelationNode* op) override {
// arguments to type relation should be normal types
for (const Type& t : op->args) {
this->VisitType(t);
valid = valid && IsTypeKind(t);
if (!valid) {
return;
CheckKindMatches(t, GetRef<TypeRelation>(op), Kind::kType,
"argument to type relation");
}
return Kind::kConstraint;
}
Kind VisitType_(const TypeCallNode* op) override {
// type call func should be a global type var, args should be type
TypeCall tc = GetRef<TypeCall>(op);
const auto* gtv = op->func.as<GlobalTypeVarNode>();
if (gtv == nullptr) {
ReportFatalError(RELAY_ERROR("The callee in " << tc
<< " is not a global type var, but is " << op->func));
}
CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function");
for (const Type& t : op->args) {
CheckKindMatches(t, tc, Kind::kType, "type call argument");
}
// finally we need to check the module to check the number of type params
auto var = GetRef<GlobalTypeVar>(gtv);
auto data = mod->LookupDef(var);
if (data->type_vars.size() != op->args.size()) {
ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc
<< "; got " << op->args.size()));
}
return Kind::kType;
}
Kind VisitType_(const TypeDataNode* op) override {
// Constructors can reference the header var, but no other GlobalTypeVars.
// In theory, a TypeData could be nested, so the header scope
// should be tracked recursively, but it is unclear that we need
// to support it.
TypeData td = GetRef<TypeData>(op);
CheckKindMatches(op->header, td, Kind::kAdtHandle, "type data header");
for (const auto& var : op->type_vars) {
CheckKindMatches(var, td, Kind::kType, "ADT type var");
}
for (const auto& con : op->constructors) {
if (!con->belong_to.same_as(op->header)) {
ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to
<< " but " << op << "has header " << op->header));
}
for (const Type& t : con->inputs) {
CheckKindMatches(t, td, Kind::kType, "ADT constructor input");
}
}
return Kind::kTypeData;
}
bool Check(const Type& t) {
this->VisitType(t);
return valid;
Kind Check(const Type& t) {
return this->VisitType(t);
}
};
bool KindCheck(const Type& t, const Module& mod) {
KindChecker kc;
Kind KindCheck(const Type& t, const Module& mod) {
KindChecker kc(mod);
return kc.Check(t);
}
TVM_REGISTER_API("relay._ir_pass.check_kind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
if (args.size() == 1) {
*ret = KindCheck(args[0], ModuleNode::make({}));
*ret = KindCheck(args[0], ModuleNode::make({}, {}));
} else {
*ret = KindCheck(args[0], args[1]);
}
......
......@@ -62,7 +62,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Expr expr) {
return Push(IncompleteTypeNode::make(TypeVarNode::kType), expr);
return Push(IncompleteTypeNode::make(Kind::kType), expr);
}
/*!
......
......@@ -274,7 +274,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
}
Expr VisitExpr(const Expr& e) {
Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType));
Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(Kind::kType));
return this->VisitExpr(e, v);
}
......
......@@ -189,6 +189,20 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return RefTypeNode::make(Unify(op->value, rtn->value));
}
Type VisitType_(const TypeCallNode* op, const Type& tn) override {
const auto* tcn = tn.as<TypeCallNode>();
if (!tcn || tcn->args.size() != op->args.size()) {
return Type();
}
Type func = Unify(op->func, tcn->func);
tvm::Array<Type> args;
for (size_t i = 0; i < op->args.size(); i++) {
args.push_back(Unify(op->args[i], tcn->args[i]));
}
return TypeCallNode::make(func, args);
}
private:
TypeSolver* solver_;
};
......@@ -266,6 +280,16 @@ class TypeSolver::Propagator : public TypeFunctor<void(const Type&)> {
}
}
void VisitType_(const TypeCallNode* op) override {
TypeCall tc = GetRef<TypeCall>(op);
UpdateRelSet(tc);
Propagate(tc->func);
for (auto arg : tc->args) {
Propagate(arg);
}
}
private:
TypeSolver* solver_;
const std::unordered_set<RelationNode*>* rels_;
......@@ -494,7 +518,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
Expr e = VarNode::make("dummy_var",
IncompleteTypeNode::make(TypeVarNode::Kind::kType));
IncompleteTypeNode::make(Kind::kType));
return solver->AddConstraint(c, e);
});
} else {
......
......@@ -7,6 +7,7 @@
*/
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "../ir/type_functor.h"
namespace tvm {
......@@ -29,7 +30,7 @@ class TypeVarTVisitor : public TypeVisitor {
TypeVarTVisitor(
InsertionSet<TypeVar>* type_vars,
InsertionSet<TypeVar>* bound_type_vars)
: type_vars_(type_vars), bound_type_vars_(bound_type_vars) { }
: type_vars_(type_vars), bound_type_vars_(bound_type_vars) { }
void VisitType_(const TypeVarNode* tp) final {
TypeVar var = GetRef<TypeVar>(tp);
......@@ -51,6 +52,8 @@ class TypeVarTVisitor : public TypeVisitor {
class TypeVarEVisitor : private ExprVisitor {
public:
explicit TypeVarEVisitor(const Module& mod) : mod_(mod) {}
Array<TypeVar> CollectFree() {
Array<TypeVar> ret;
for (const auto& v : type_vars_.data) {
......@@ -115,6 +118,16 @@ class TypeVarEVisitor : private ExprVisitor {
ExprVisitor::VisitExpr_(f);
}
void VisitExpr_(const ConstructorNode* cn) final {
// for constructors, type vars will be bound in the module
auto data = mod_->LookupDef(cn->belong_to);
for (const auto& tv : data->type_vars) {
type_vars_.Insert(tv);
bound_type_vars_.Insert(tv);
}
ExprVisitor::VisitExpr_(cn);
}
void VisitType(const Type& t) final {
TypeVarTVisitor(&type_vars_, &bound_type_vars_)
.VisitType(t);
......@@ -123,9 +136,10 @@ class TypeVarEVisitor : private ExprVisitor {
private:
InsertionSet<TypeVar> type_vars_;
InsertionSet<TypeVar> bound_type_vars_;
const Module& mod_;
};
class VarVisitor : protected ExprVisitor {
class VarVisitor : protected ExprVisitor, protected PatternVisitor {
public:
Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr);
......@@ -178,33 +192,41 @@ class VarVisitor : protected ExprVisitor {
VisitExpr(op->body);
}
void VisitPattern(const Pattern& p) final {
PatternVisitor::VisitPattern(p);
}
void VisitPattern_(const PatternVarNode* op) final {
MarkBounded(op->var);
}
private:
InsertionSet<Var> vars_;
InsertionSet<Var> bound_vars_;
};
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr) {
return TypeVarEVisitor().Free(expr);
tvm::Array<TypeVar> FreeTypeVars(const Expr& expr, const Module& mod) {
return TypeVarEVisitor(mod).Free(expr);
}
tvm::Array<TypeVar> FreeTypeVars(const Type& type) {
return TypeVarEVisitor().Free(type);
tvm::Array<TypeVar> FreeTypeVars(const Type& type, const Module& mod) {
return TypeVarEVisitor(mod).Free(type);
}
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr) {
return TypeVarEVisitor().Bound(expr);
tvm::Array<TypeVar> BoundTypeVars(const Expr& expr, const Module& mod) {
return TypeVarEVisitor(mod).Bound(expr);
}
tvm::Array<TypeVar> BoundTypeVars(const Type& type) {
return TypeVarEVisitor().Bound(type);
tvm::Array<TypeVar> BoundTypeVars(const Type& type, const Module& mod) {
return TypeVarEVisitor(mod).Bound(type);
}
tvm::Array<TypeVar> AllTypeVars(const Expr& expr) {
return TypeVarEVisitor().All(expr);
tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const Module& mod) {
return TypeVarEVisitor(mod).All(expr);
}
tvm::Array<TypeVar> AllTypeVars(const Type& type) {
return TypeVarEVisitor().All(type);
tvm::Array<TypeVar> AllTypeVars(const Type& type, const Module& mod) {
return TypeVarEVisitor(mod).All(type);
}
tvm::Array<Var> FreeVars(const Expr& expr) {
......@@ -237,30 +259,33 @@ TVM_REGISTER_API("relay._ir_pass.all_vars")
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) {
*ret = FreeTypeVars(Downcast<Type>(x));
*ret = FreeTypeVars(Downcast<Type>(x), mod);
} else {
*ret = FreeTypeVars(Downcast<Expr>(x));
*ret = FreeTypeVars(Downcast<Expr>(x), mod);
}
});
TVM_REGISTER_API("relay._ir_pass.bound_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) {
*ret = BoundTypeVars(Downcast<Type>(x));
*ret = BoundTypeVars(Downcast<Type>(x), mod);
} else {
*ret = BoundTypeVars(Downcast<Expr>(x));
*ret = BoundTypeVars(Downcast<Expr>(x), mod);
}
});
TVM_REGISTER_API("relay._ir_pass.all_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[0];
Module mod = args[1];
if (x.as_derived<TypeNode>()) {
*ret = AllTypeVars(Downcast<Type>(x));
*ret = AllTypeVars(Downcast<Type>(x), mod);
} else {
*ret = AllTypeVars(Downcast<Expr>(x));
*ret = AllTypeVars(Downcast<Expr>(x), mod);
}
});
......
......@@ -13,7 +13,10 @@ TEST(Relay, SelfReference) {
auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto type_fx = relay::InferType(fx, relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{}));
auto empty_module =
relay::ModuleNode::make(Map<relay::GlobalVar, relay::Function>{},
Map<relay::GlobalTypeVar, relay::TypeData>{});
auto type_fx = relay::InferType(fx, empty_module);
auto expected = relay::FuncTypeNode::make(tvm::Array<relay::Type>{ tensor_type }, tensor_type, {}, {});
CHECK(AlphaEqual(type_fx->checked_type(), expected));
......
......@@ -171,6 +171,29 @@ def test_type_relation_alpha_equal():
assert bigger != diff_num_inputs
def test_type_call_alpha_equal():
h1 = relay.GlobalTypeVar("h1")
h2 = relay.GlobalTypeVar("h2")
t1 = relay.TensorType((1, 2), "float32")
t2 = relay.TensorType((1, 2, 3), "float32")
t3 = relay.TensorType((1, 2, 3, 4), "float32")
t4 = relay.TensorType((), "float32")
tc = relay.TypeCall(h1, [t1, t2, t3])
same = relay.TypeCall(h1, [t1, t2, t3])
different_func = relay.TypeCall(h2, [t1, t2, t3])
different_arg = relay.TypeCall(h1, [t1, t2, t4])
fewer_args = relay.TypeCall(h1, [t1, t2])
more_args = relay.TypeCall(h1, [t1, t2, t3, t4])
different_order_args = relay.TypeCall(h1, [t3, t2, t1])
assert tc == same
assert tc != different_func
assert tc != fewer_args
assert tc != more_args
assert tc != different_order_args
def test_constant_alpha_equal():
x = relay.const(1)
......@@ -453,6 +476,79 @@ def test_if_alpha_equal():
assert not alpha_equal(if_sample, different_false)
def test_constructor_alpha_equal():
# smoke test: it should be pointer equality
mod = relay.Module()
p = relay.prelude.Prelude(mod)
assert alpha_equal(p.nil, p.nil)
assert alpha_equal(p.cons, p.cons)
assert not alpha_equal(p.nil, p.cons)
def test_match_alpha_equal():
mod = relay.Module()
p = relay.prelude.Prelude(mod)
x = relay.Var('x')
y = relay.Var('y')
nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil())
cons_case = relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(x),
relay.PatternVar(y)]),
p.cons(x, y))
z = relay.Var('z')
a = relay.Var('a')
equivalent_cons = relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(z),
relay.PatternVar(a)]),
p.cons(z, a))
data = p.cons(p.z(), p.cons(p.z(), p.nil()))
match = relay.Match(data, [nil_case, cons_case])
equivalent = relay.Match(data, [nil_case, equivalent_cons])
empty = relay.Match(data, [])
no_cons = relay.Match(data, [nil_case])
no_nil = relay.Match(data, [cons_case])
different_data = relay.Match(p.nil(), [nil_case, cons_case])
different_order = relay.Match(data, [cons_case, nil_case])
different_nil = relay.Match(data, [
relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())),
cons_case
])
different_cons = relay.Match(data, [
nil_case,
relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternWildcard(),
relay.PatternWildcard()]),
p.nil())
])
another_case = relay.Match(data, [
nil_case,
cons_case,
relay.Clause(relay.PatternWildcard(), p.nil())
])
wrong_constructors = relay.Match(data, [
relay.Clause(relay.PatternConstructor(p.z), p.nil()),
relay.Clause(relay.PatternConstructor(p.s, [relay.PatternVar(x)]),
p.cons(x, p.nil()))
])
assert alpha_equal(match, match)
assert alpha_equal(match, equivalent)
assert not alpha_equal(match, no_cons)
assert not alpha_equal(match, no_nil)
assert not alpha_equal(match, empty)
assert not alpha_equal(match, different_data)
assert not alpha_equal(match, different_order)
assert not alpha_equal(match, different_nil)
assert not alpha_equal(match, different_cons)
assert not alpha_equal(match, another_case)
assert not alpha_equal(match, wrong_constructors)
def test_op_alpha_equal():
# only checks names
op1 = relay.op.get("add")
......@@ -491,6 +587,7 @@ if __name__ == "__main__":
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
test_type_call_alpha_equal()
test_constant_alpha_equal()
test_global_var_alpha_equal()
test_tuple_alpha_equal()
......@@ -499,6 +596,8 @@ if __name__ == "__main__":
test_call_alpha_equal()
test_let_alpha_equal()
test_if_alpha_equal()
test_constructor_alpha_equal()
test_match_alpha_equal()
test_op_alpha_equal()
test_var_alpha_equal()
test_graph_equal()
import tvm
from tvm import relay
from tvm.relay.ir_pass import check_kind
from nose.tools import raises
def test_typevar_kind():
# returns the same kind
tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tp3 = relay.TypeVar('tp3', relay.Kind.Constraint)
assert check_kind(tp1) == relay.Kind.Type
assert check_kind(tp2) == relay.Kind.Shape
assert check_kind(tp3) == relay.Kind.Constraint
def test_tuple_kind():
# only contain type kinds
......@@ -10,7 +23,7 @@ def test_tuple_kind():
fields = tvm.convert([tp, tf, tt])
tup_ty = relay.TupleType(fields)
assert check_kind(tup_ty)
assert check_kind(tup_ty) == relay.Kind.Type
def test_func_kind():
......@@ -30,7 +43,20 @@ def test_func_kind():
ret_type = relay.TupleType(tvm.convert([tp2, tensor_type]))
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert check_kind(tf)
assert check_kind(tf) == relay.Kind.Type
def test_ref_kind():
# only contain type kinds
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
ft = relay.FuncType(tvm.convert([]), tt, tvm.convert([]), tvm.convert([]))
rt1 = relay.RefType(tt)
assert check_kind(rt1) == relay.Kind.Type
rt2 = relay.RefType(ft)
assert check_kind(rt2) == relay.Kind.Type
rt3 = relay.RefType(relay.TupleType([rt1, rt2]))
assert check_kind(rt3) == relay.Kind.Type
def test_relation_kind():
......@@ -41,9 +67,35 @@ def test_relation_kind():
args = tvm.convert([tf, tt, tp])
tr = relay.TypeRelation(None, args, 2, None)
assert check_kind(tr)
assert check_kind(tr) == relay.Kind.Constraint
def test_global_typevar_kind():
v1 = relay.GlobalTypeVar('gtv1', relay.Kind.AdtHandle)
v2 = relay.GlobalTypeVar('gtv2', relay.Kind.Type)
assert check_kind(v1) == relay.Kind.AdtHandle
assert check_kind(v2) == relay.Kind.Type
def test_typecall_kind():
gtv = relay.GlobalTypeVar('gtv')
mod = relay.Module()
data = relay.TypeData(gtv, [], [])
mod[gtv] = data
empty_call = relay.TypeCall(gtv, [])
assert check_kind(empty_call, mod) == relay.Kind.Type
new_mod = relay.Module()
tv = relay.TypeVar('tv')
new_data = relay.TypeData(gtv, [tv], [])
new_mod[gtv] = new_data
call = relay.TypeCall(gtv, [relay.TupleType([])])
assert check_kind(call, new_mod) == relay.Kind.Type
@raises(tvm._ffi.base.TVMError)
def test_invalid_tuple_kind():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
......@@ -51,9 +103,10 @@ def test_invalid_tuple_kind():
fields = tvm.convert([tp1, tp2, tp3])
tup_ty = relay.TupleType(fields)
assert not check_kind(tup_ty)
check_kind(tup_ty)
@raises(tvm._ffi.base.TVMError)
def test_invalid_func_kind():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
......@@ -65,51 +118,98 @@ def test_invalid_func_kind():
ret_type = tp3
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert not check_kind(tf)
check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_invalid_ref_kind():
tp = relay.TypeVar('tp', relay.Kind.Shape)
rt = relay.RefType(tp)
check_kind(rt)
@raises(tvm._ffi.base.TVMError)
def test_invalid_relation_kind():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.BaseType)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
args = tvm.convert([tp1, tp2, tp3])
tr = relay.TypeRelation(None, args, 2, None)
assert not check_kind(tr)
func = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
tr = relay.TypeRelation(func, args, 2, None)
check_kind(tr)
@raises(tvm._ffi.base.TVMError)
def test_typecall_invalid_callee():
# global type var must be an ADT handle
gtv = relay.GlobalTypeVar('v1', relay.Kind.Type)
check_kind(relay.TypeCall(gtv, []))
@raises(tvm._ffi.base.TVMError)
def test_typecall_invalid_args():
# args must all be type kind
mod = relay.Module()
gtv = relay.GlobalTypeVar('v1')
data = relay.TypeData(gtv, [], [])
mod[gtv] = data
check_kind(relay.TypeCall(gtv, [data]))
@raises(tvm._ffi.base.TVMError)
def test_typecall_invalid_num_args():
mod = relay.Module()
gtv = relay.GlobalTypeVar('v1')
tv = relay.TypeVar('tv')
data = relay.TypeData(gtv, [tv], [])
mod[gtv] = data
check_kind(relay.TypeCall(gtv, []))
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_ret_type():
tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_arg_types():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
tp2 = relay.TypeVar('tp2', relay.Kind.Type)
tf = relay.FuncType(tvm.convert([tp1]), tp2, tvm.convert([tp1, tp2]), tvm.convert([]))
check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_tuple():
tp1 = relay.TypeVar('tp1', relay.Kind.Shape)
ret_type = relay.TupleType(tvm.convert([tp1, tp1, tp1]))
tf = relay.FuncType(tvm.convert([]), ret_type, tvm.convert([tp1]), tvm.convert([]))
assert not check_kind(tf)
check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_func_with_invalid_relation():
tp1 = relay.TypeVar('tp1', relay.Kind.Type)
tp2 = relay.TypeVar('tp2', relay.Kind.Shape)
tp3 = relay.TypeVar('tp3', relay.Kind.ShapeVar)
tr = relay.TypeRelation(None, tvm.convert([tp2, tp3]), 1, None)
func = tvm.get_env_func("tvm.relay.type_relation.Identity")
tr = relay.TypeRelation(func, tvm.convert([tp2, tp3]), 1, None)
tf = relay.FuncType(tvm.convert([tp1]), tp1, tvm.convert([tp1, tp2, tp3]), tvm.convert([tr]))
assert not check_kind(tf)
check_kind(tf)
@raises(tvm._ffi.base.TVMError)
def test_tuple_with_invalid_func():
tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
......@@ -117,16 +217,23 @@ def test_tuple_with_invalid_func():
tf = relay.FuncType(tvm.convert([]), tp1, tvm.convert([tp1]), tvm.convert([]))
tup_ty = relay.TupleType(tvm.convert([tensor_type, tf]))
assert not check_kind(tup_ty)
check_kind(tup_ty)
if __name__ == "__main__":
test_tuple_kind()
test_func_kind()
test_ref_kind()
test_relation_kind()
test_global_typevar_kind()
test_typecall_kind()
test_invalid_tuple_kind()
test_invalid_func_kind()
test_invalid_ref_kind()
test_invalid_relation_kind()
test_typecall_invalid_callee()
test_typecall_invalid_args()
test_typecall_invalid_num_args()
test_func_with_invalid_ret_type()
test_func_with_invalid_arg_types()
test_func_with_invalid_tuple()
......
......@@ -65,6 +65,40 @@ def test_bound_vars():
assert_vars_match(bound_vars(f2), [x, y])
def test_match_vars():
mod = relay.Module()
p = relay.prelude.Prelude(mod)
x = relay.Var('x')
y = relay.Var('y')
z = relay.Var('z')
match1 = relay.Match(p.nil(), [
relay.Clause(relay.PatternConstructor(p.nil), z),
relay.Clause(relay.PatternConstructor(p.cons,
[relay.PatternVar(x),
relay.PatternVar(y)]),
p.cons(x, y))
])
match2 = relay.Match(p.nil(), [
relay.Clause(relay.PatternConstructor(p.cons, [
relay.PatternWildcard(),
relay.PatternVar(x)
]),
y),
relay.Clause(relay.PatternWildcard(), z)
])
assert_vars_match(bound_vars(match1), [x, y])
assert_vars_match(free_vars(match1), [z])
assert_vars_match(all_vars(match1), [z, x, y])
assert_vars_match(bound_vars(match2), [x])
assert_vars_match(free_vars(match2), [y, z])
assert_vars_match(all_vars(match2), [x, y, z])
def test_bound_type_vars():
a = relay.TypeVar("a")
b = relay.TypeVar("b")
......@@ -127,7 +161,7 @@ def test_all_type_vars():
x = relay.Var("x", a)
y = relay.Var("y", b)
z = relay.Var("z", c)
f1 = relay.Function([x], y, b, [a])
assert_vars_match(all_type_vars(f1), [a, b])
......
......@@ -17,6 +17,16 @@ def assert_has_type(expr, typ, mod=relay.module.Module({})):
checked_type, typ))
# initializes simple ADT for tests
def initialize_box_adt(mod):
box = relay.GlobalTypeVar('box')
tv = relay.TypeVar('tv')
constructor = relay.Constructor('constructor', [tv], box)
data = relay.TypeData(box, [tv], [constructor])
mod[box] = data
return (box, constructor)
def test_monomorphic_let():
"Program: let x = 1; return x"
sb = relay.ScopeBuilder()
......@@ -190,6 +200,69 @@ def test_equal():
assert ft.checked_type == relay.FuncType([relay.scalar_type('int32')], relay.scalar_type('bool'))
def test_constructor_type():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
a = relay.TypeVar('a')
x = relay.Var('x', a)
ct = relay.ir_pass.infer_type(
relay.Function([x], constructor(x), box(a), [a]), mod)
expected = relay.FuncType([a], box(a), [a])
assert ct.checked_type == expected
def test_constructor_call():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
box_unit = constructor(relay.Tuple([]))
box_constant = constructor(relay.const(0, 'float32'))
ut = relay.ir_pass.infer_type(box_unit, mod)
ct = relay.ir_pass.infer_type(box_constant, mod)
assert ut.checked_type == box(relay.TupleType([]))
assert ct.checked_type == box(relay.TensorType((), 'float32'))
def test_adt_match():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
v = relay.Var('v', relay.TensorType((), 'float32'))
match = relay.Match(constructor(relay.const(0, 'float32')),
[relay.Clause(
relay.PatternConstructor(constructor,
[relay.PatternVar(v)]),
relay.Tuple([])),
# redundant but shouldn't matter to typechecking
relay.Clause(relay.PatternWildcard(),
relay.Tuple([]))])
mt = relay.ir_pass.infer_type(match, mod)
assert mt.checked_type == relay.TupleType([])
def test_adt_match_type_annotations():
mod = relay.Module()
box, constructor = initialize_box_adt(mod)
# the only type annotation is inside the match pattern var
# but that should be enough info
tt = relay.TensorType((2, 2), 'float32')
x = relay.Var('x')
mv = relay.Var('mv', tt)
match = relay.Match(constructor(x),
[relay.Clause(
relay.PatternConstructor(constructor,
[relay.PatternVar(mv)]),
relay.Tuple([]))])
func = relay.Function([x], match)
ft = relay.ir_pass.infer_type(func, mod)
assert ft.checked_type == relay.FuncType([tt], relay.TupleType([]))
if __name__ == "__main__":
test_free_expr()
test_dual_op()
......@@ -205,3 +278,6 @@ if __name__ == "__main__":
test_global_var_recursion()
test_equal()
test_ref()
test_constructor_type()
test_constructor_call()
test_adt_match()
......@@ -62,6 +62,30 @@ def test_unify_tuple():
assert unified == tup2
def test_unify_global_type_var():
# should only be able to unify if they're the same
solver = make_solver()
gtv = relay.GlobalTypeVar('gtv')
unified = solver.Unify(gtv, gtv)
assert unified == gtv
def test_unify_typecall():
solver = make_solver()
gtv = relay.GlobalTypeVar('gtv')
# yeah, typecalls are shaped like tuples so the same
# tests work out
t1 = relay.ty.IncompleteType()
t2 = relay.ty.IncompleteType()
t3 = relay.ty.TensorType((10, 20), "float32")
tc1 = relay.ty.TypeCall(gtv, [t1, t2])
tc2 = relay.ty.TypeCall(gtv, [t3, t3])
unified = solver.Unify(tc1, tc2)
assert unified == tc2
def test_unify_functype():
solver = make_solver()
t1 = relay.ty.IncompleteType()
......@@ -205,10 +229,49 @@ def test_bad_recursive_unification():
solver.Unify(t1, relay.ty.TupleType([t1, t1]))
@raises(tvm._ffi.base.TVMError)
def test_unify_invalid_global_typevars():
solver = make_solver()
gtv1 = relay.GlobalTypeVar('gtv1')
gtv2 = relay.GlobalTypeVar('gtv2')
solver.Unify(gtv1, gtv2)
@raises(tvm._ffi.base.TVMError)
def test_incompatible_typecall_var_unification():
solver = make_solver()
gtv1 = relay.GlobalTypeVar('gtv1')
gtv2 = relay.GlobalTypeVar('gtv2')
t1 = relay.IncompleteType()
t2 = relay.IncompleteType()
tc1 = relay.TypeCall(gtv1, [t1])
tc2 = relay.TypeCall(gtv2, [t2])
solver.Unify(tc1, tc2)
@raises(tvm._ffi.base.TVMError)
def test_incompatible_typecall_args_unification():
solver = make_solver()
gtv = relay.GlobalTypeVar('gtv1')
t1 = relay.IncompleteType()
t2 = relay.IncompleteType()
tensor1 = relay.TensorType((1, 2, 3), "float32")
tensor2 = relay.TensorType((2, 3), "float32")
tensor3 = relay.TensorType((3,), "float32")
tc1 = relay.TypeCall(gtv, [relay.TupleType([t1, t1]), t2])
tc2 = relay.TypeCall(gtv, [relay.TupleType([tensor1, tensor2]), tensor3])
solver.Unify(tc1, tc2)
if __name__ == "__main__":
test_bcast()
test_backward_solving()
test_unify_tuple()
test_unify_typecall()
test_unify_functype()
test_recursive_unify()
test_unify_vars_under_tuples()
......@@ -216,3 +279,5 @@ if __name__ == "__main__":
test_backward_solving_after_child_update()
test_incompatible_tuple_unification()
test_bad_recursive_unification()
test_incompatible_typecall_var_unification()
test_incompatible_typecall_args_unification()
from tvm import relay
from tvm.relay.ir_pass import infer_type
def test_dup_type():
a = relay.TypeVar("a")
av = relay.Var("av", a)
make_id = relay.Function([av], relay.Tuple([av, av]), None, [a])
t = relay.scalar_type("float32")
b = relay.Var("b", t)
assert relay.ir_pass.infer_type(make_id(b)).checked_type == relay.TupleType([t, t])
def test_id_type():
mod = relay.Module()
id_type = relay.GlobalTypeVar("id")
a = relay.TypeVar("a")
mod[id_type] = relay.TypeData(id_type, [a], [])
b = relay.TypeVar("b")
make_id = relay.Var("make_id", relay.FuncType([b], id_type(b), [b]))
t = relay.scalar_type("float32")
b = relay.Var("b", t)
assert relay.ir_pass.infer_type(make_id(b), mod).checked_type == id_type(t)
if __name__ == "__main__":
test_dup_type()
test_id_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment