Commit 51fe00fb by Jared Roesch Committed by Tianqi Chen

[High level OPT][RFC] NNVMv2 IR - Relay (#1672)

parent 543c4240
......@@ -104,6 +104,12 @@ file(GLOB COMPILER_SRCS
src/schedule/*.cc
)
file(GLOB_RECURSE RELAY_SRCS
src/relay/*.cc
)
list(APPEND COMPILER_SRCS ${RELAY_SRCS})
if(NOT MSVC)
file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc)
list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS})
......
......@@ -33,7 +33,7 @@ sys.path.insert(0, os.path.join(curr_path, '../vta/python'))
# General information about the project.
project = u'tvm'
author = u'%s developers' % project
copyright = u'2017, %s' % author
copyright = u'2018, %s' % author
github_doc_root = 'https://github.com/tqchen/tvm/tree/master/docs/'
# add markdown parser
......
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/base.h
* \brief Base classes for the Relay IR.
*/
#ifndef TVM_RELAY_BASE_H_
#define TVM_RELAY_BASE_H_
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
#include <string>
#include <vector>
namespace tvm {
/*!
* \brief Relay: a high level functional IR for TVM.
*
* This namespace contains the abstract syntax tree, and other
* essential data structures for the Relay IR.
*
* You can find more about Relay by reading the language reference.
*/
namespace relay {
/*!
 * \brief We always use NodeRef for referencing nodes.
*
* By default, NodeRef is a std::shared_ptr of node
*/
using NodeRef = tvm::NodeRef;
/*!
* \brief Content data type.
*/
using DataType = ::tvm::Type;
/*!
* \brief Symbolic expression for tensor shape.
*/
using ShapeExpr = ::tvm::Expr;
/*!
* \brief Hash function for nodes.
* e.g. std::unordered_map<Expr, Value, NodeHash, NodeEqual>
*/
using NodeHash = ::tvm::NodeHash;
/*!
* \brief Equality check function for nodes.
*/
using NodeEqual = ::tvm::NodeEqual;
/*!
* \brief Macro to make it easy to define node ref type given node
* \param TypeName The name of the reference type.
* \param NodeName The internal container name.
* \param NodeRefBase The base type.
*/
#define RELAY_DEFINE_NODE_REF(TypeName, NodeName, NodeRefBase)            \
  class TypeName : public NodeRefBase {                                   \
   public:                                                                \
    /*! \brief default constructor, creates an undefined reference */     \
    TypeName() {}                                                         \
    /*! \brief construct from a shared node pointer */                    \
    explicit TypeName(std::shared_ptr<::tvm::Node> n) : NodeRefBase(n) {} \
    /*! \brief access the underlying container node */                    \
    const NodeName* operator->() const {                                  \
      return static_cast<const NodeName*>(node_.get());                   \
    }                                                                     \
    /*! \brief check whether the reference is defined (non-null);         \
     *  const-qualified so it also works on const references */           \
    operator bool() const { return this->defined(); }                     \
    /*! \brief the internal container type */                             \
    using ContainerType = NodeName;                                       \
  };
/*!
* \brief The source name in the Span
* \sa SourceNameNode, Span
*/
class SourceName;
/*!
* \brief The name of a source fragment.
*/
/*! \brief Container node holding the name of a source fragment. */
class SourceNameNode : public Node {
 public:
  /*! \brief The source name. */
  std::string name;
  // override attr visitor; `name` is the only serialized field
  void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); }
  /*!
   * \brief Construct a SourceName from a string.
   * \param name The source name.
   * \return The constructed reference.
   */
  TVM_DLL static SourceName make(std::string name);

  static constexpr const char* _type_key = "relay.SourceName";
  TVM_DECLARE_NODE_TYPE_INFO(SourceNameNode, Node);
};
/*!
* \brief The source name of a file span.
* \sa SourceNameNode, Span
*/
class SourceName : public NodeRef {
 public:
  /*! \brief default constructor */
  SourceName() {}

  /*! \brief constructor from node pointer */
  explicit SourceName(std::shared_ptr<Node> n) : NodeRef(n) {}

  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   * \note defined out of line (definition not visible in this header chunk)
   */
  inline const SourceNameNode* operator->() const;

  /*!
   * \brief Get an SourceName for a given operator name.
   * Will raise an error if the source name has not been registered.
   * \param name Name of the operator.
   * \return Reference to a SourceName valid throughout program lifetime.
   */
  TVM_DLL static const SourceName& Get(const std::string& name);

  /*! \brief specify container node */
  using ContainerType = SourceNameNode;
};
/*!
* \brief Span information for debugging purposes
*/
class Span;
/*!
* \brief Stores locations in frontend source that generated a node.
*/
/*! \brief Container node recording where in frontend source a node came from. */
class SpanNode : public Node {
 public:
  /*! \brief The source name */
  SourceName source;
  /*! \brief Line number */
  int lineno;
  /*! \brief column offset */
  int col_offset;

  // override attr visitor; all three fields are serialized
  void VisitAttrs(AttrVisitor* v) final {
    v->Visit("source", &source);
    v->Visit("lineno", &lineno);
    v->Visit("col_offset", &col_offset);
  }

  /*!
   * \brief Construct a Span.
   * \param source The source fragment name.
   * \param lineno The line number.
   * \param col_offset The column offset.
   * \return The constructed Span reference.
   */
  TVM_DLL static Span make(SourceName source, int lineno, int col_offset);

  static constexpr const char* _type_key = "relay.Span";
  TVM_DECLARE_NODE_TYPE_INFO(SpanNode, Node);
};

// Reference type wrapping SpanNode.
RELAY_DEFINE_NODE_REF(Span, SpanNode, NodeRef);
/*!
 * \brief This is the base node container of all relay structures.
 */
class RelayNode : public Node {
 public:
  /*! \brief The location of the program in a SourceFragment can be null,
   * check with span.defined() */
  // mutable: spans may be attached/updated on otherwise-const nodes
  mutable Span span;

  static constexpr const char* _type_key = "relay.Node";
  TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
};
/*!
 * \brief Get a reference type from a Node ptr type
 *
 * It is always important to get a reference type
 * if we want to return a value as reference or keep
 * the node alive beyond the scope of the function.
 *
 * \param ptr The node pointer
 * \tparam RefType The reference type
 * \tparam NodeType The node type
 * \return The corresponding RefType
 */
template <typename RefType, typename NodeType>
RefType GetRef(const NodeType* ptr) {
  // compile-time guard: RefType must be the reference wrapper for NodeType
  static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
                "Can only cast to the ref of same container type");
  // const_cast is needed because shared_from_this() is non-const;
  // the resulting reference shares ownership with the node
  return RefType(const_cast<NodeType*>(ptr)->shared_from_this());
}
// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
/*!
 * \brief Checked downcast of a NodeRef's payload to node type T.
 * \param node The reference to inspect.
 * \tparam T The target node type.
 * \return Pointer to the node as T, or nullptr when the payload is null
 *         or is not (a subtype of) T.
 */
template <typename T>
inline const T* As(const NodeRef& node) {
  const Node* ptr = static_cast<const Node*>(node.get());
  // accept both an exact type match and a derived type
  if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
    return static_cast<const T*>(ptr);
  }
  return nullptr;
}
/*!
 * \brief Downcast a base reference type to a more specific reference type.
 * \param ref The base reference.
 * \tparam SubRef The target reference type.
 * \tparam BaseRef The source reference type.
 * \return ref converted to SubRef.
 * \note Aborts via CHECK when the payload is not exactly
 *       SubRef::ContainerType (unlike As(), derived types are not accepted).
 */
template <typename SubRef, typename BaseRef>
SubRef Downcast(BaseRef ref) {
  CHECK(ref->template is_type<typename SubRef::ContainerType>())
      << "Downcast from " << ref->type_key() << " to "
      << SubRef::ContainerType::_type_key << " failed.";
  return SubRef(ref.node_);
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_BASE_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/environment.h
* \brief The global environment: contains information needed to
* compile & optimize Relay programs.
*/
#ifndef TVM_RELAY_ENVIRONMENT_H_
#define TVM_RELAY_ENVIRONMENT_H_
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <string>
#include <vector>
namespace tvm {
namespace relay {
struct Environment;
/*! \brief The global environment of Relay programs.
*
* The global environment contains the global
* information needed to compile a Relay program.
*
* It contains all global functions, and configuration
* options.
*
* Many operations require access to the global
* Environment. We pass the Environment by value
* in a functional style as an explicit argument,
* but we mutate the Environment while optimizing
* Relay programs.
*
* The functional style allows users to construct custom
* environments easily, for example each thread can store
* an Environment while auto-tuning.
* */
class EnvironmentNode : public RelayNode {
public:
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
EnvironmentNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_map_", &global_map_);
}
TVM_DLL static Environment make(tvm::Map<GlobalVar, Function> global_funcs);
/*! \brief Add a function to the global environment.
* \param var The name of the global function.
* \param func The function.
* \param update Controls whether you can replace a definition in the
* environment.
*/
void Add(const GlobalVar& var, const Function& func, bool update = false);
/*! \brief Update a function in the global environment.
* \param var The name of the global function to update.
* \param func The new function.
*/
void Update(const GlobalVar& var, const Function& func);
/*! \brief Remove a function from the global environment.
* \param var The name of the global function to update.
*/
void Remove(const GlobalVar& var);
/*! \brief Lookup a global function by its variable.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
GlobalVar GetGlobalVar(const std::string& str);
/*! \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
Function Lookup(const GlobalVar& var);
/*! \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
Function Lookup(const std::string& name);
/*! \brief Combine with another Environment.
* \param other The other environment.
*/
void Merge(const Environment& other);
static constexpr const char* _type_key = "relay.Environment";
TVM_DECLARE_NODE_TYPE_INFO(EnvironmentNode, Node);
private:
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
tvm::Map<std::string, GlobalVar> global_map_;
};
/*! \brief Reference type wrapping EnvironmentNode. */
struct Environment : public NodeRef {
  /*! \brief default constructor, creates an undefined reference */
  Environment() {}
  /*! \brief construct from a shared node pointer */
  explicit Environment(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}

  // non-const pointer access: the environment is mutated in place
  // while optimizing programs (see class comment above)
  inline EnvironmentNode* operator->() const {
    return static_cast<EnvironmentNode*>(node_.get());
  }

  /*! \brief specify container node */
  using ContainerType = EnvironmentNode;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ENVIRONMENT_H_
/*!
* Copyright (c) 2018 by Contributors
* \file error.h
* \brief The set of errors raised by Relay.
*/
#ifndef TVM_RELAY_ERROR_H_
#define TVM_RELAY_ERROR_H_
#include <string>
#include "./base.h"
namespace tvm {
namespace relay {
struct Error : dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
};
struct InternalError : Error {
explicit InternalError(const std::string &msg) : Error(msg) {}
};
// TODO(@jroesch): we should change spanned errors to report
// errors against the Environment, inverting control to error definition.
struct FatalTypeError : dmlc::Error {
explicit FatalTypeError(const std::string &s) : dmlc::Error(s) {}
};
struct TypecheckerError : public dmlc::Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ERROR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/expr_functor.h
* \brief A more powerful visitor which enables defining arbitrary function
* signatures with type based dispatch on first argument.
*/
#ifndef TVM_RELAY_EXPR_FUNCTOR_H_
#define TVM_RELAY_EXPR_FUNCTOR_H_
#include <tvm/ir_functor.h>
#include <string>
#include "./expr.h"
#include "./op.h"
namespace tvm {
namespace relay {
/*!
* \brief A dynamical functor that dispatches on in the first Expr argument.
* You can use this as a more powerful Visitor, since it allows you to
* define function signatures of Visit Function.
*
* \sa tvm/ir_functor.h
*
 * \tparam FType function signature
* This type is only defined for FType with function signature R(const Expr&,
* Args...)
*/
template <typename FType>
class ExprFunctor;
// functions to be overriden.
// Default body for a VisitExpr_ overload: forward to VisitExprDefault_,
// which raises unless a subclass overrides it.
#define EXPR_FUNCTOR_DEFAULT \
  { return VisitExprDefault_(op, std::forward<Args>(args)...); }

// Register a dispatch-table entry for expression node type OP that
// downcasts the NodeRef payload and forwards to the matching VisitExpr_.
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP)                            \
  vtable.template set_dispatch<OP>(                                \
      [](const NodeRef& n, TSelf* self, Args... args) {            \
        return self->VisitExpr_(static_cast<const OP*>(n.node_.get()), \
                                std::forward<Args>(args)...);      \
      });
template <typename R, typename... Args>
class ExprFunctor<R(const Expr& n, Args...)> {
 private:
  /*! \brief self type, used when registering dispatch entries */
  using TSelf = ExprFunctor<R(const Expr& n, Args...)>;
  /*! \brief dispatch-table type keyed by the runtime node type */
  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;

 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~ExprFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Expr& n, Args... args) {
    return VisitExpr(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitExpr(const Expr& n, Args... args) {
    // the vtable is built once and shared by all instances of this
    // instantiation
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overriden by subclass,
  // one per concrete expression node type; each defaults to
  // VisitExprDefault_ via EXPR_FUNCTOR_DEFAULT.
  virtual R VisitExpr_(const ConstantNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const TupleNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const VarNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const GlobalVarNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const ParamNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const FunctionNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const IfNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  virtual R VisitExpr_(const OpNode* op,
                       Args... args) EXPR_FUNCTOR_DEFAULT;
  /*! \brief Fallback handler: raises unless a subclass overrides it. */
  virtual R VisitExprDefault_(const Node* op, Args...) {
    throw dmlc::Error(std::string("Do not have a default for ") + op->type_key());
  }

 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_EXPR_FUNCTOR_DISPATCH(ConstantNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(TupleNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(VarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(GlobalVarNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(ParamNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(FunctionNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(CallNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
    RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
    return vtable;
  }
};
/*! \brief A simple visitor wrapper around ExprFunctor.
 *
 * Exposes two visitors with default traversal strategies, one
 * which doesn't compute a result but can mutate internal state,
 * and another which functionally builds a new Expr.
 */
class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
 public:
  // Per-node-type visitors; bodies live in the corresponding .cc file
  // (not visible in this header), presumably recursing into children.
  void VisitExpr_(const VarNode* op) override;
  void VisitExpr_(const GlobalVarNode* op) override;
  void VisitExpr_(const ConstantNode* op) override;
  void VisitExpr_(const TupleNode* op) override;
  void VisitExpr_(const ParamNode* op) override;
  void VisitExpr_(const FunctionNode* op) override;
  void VisitExpr_(const CallNode* op) override;
  void VisitExpr_(const LetNode* op) override;
  void VisitExpr_(const IfNode* op) override;
  void VisitExpr_(const OpNode* op) override;
  /*! \brief Overridable hook for visiting types referenced by expressions. */
  virtual void VisitType(const Type& t);
};
/*! \brief A wrapper around ExprFunctor which functionally updates the AST.
 *
 * ExprMutator uses memoization and self return in order to amortize
 * the cost of using functional updates.
 */
class ExprMutator
    : public ::tvm::relay::ExprFunctor<Expr(const Expr&, const Expr&)> {
 public:
  /*! \brief Entry point: mutate an expression (memoized via memo_). */
  Expr Mutate(const Expr& expr);
  // Per-node-type mutators; each receives the node pointer and the
  // original expression reference, and returns the (possibly new) Expr.
  Expr VisitExpr_(const VarNode* op, const Expr& e) override;
  Expr VisitExpr_(const ConstantNode* op, const Expr& e) override;
  Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override;
  Expr VisitExpr_(const OpNode* op, const Expr& expr) override;
  Expr VisitExpr_(const TupleNode* op, const Expr& e) override;
  Expr VisitExpr_(const ParamNode* op, const Expr& e) override;
  Expr VisitExpr_(const FunctionNode* op, const Expr& e) override;
  Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
  Expr VisitExpr_(const LetNode* op, const Expr& e) override;
  Expr VisitExpr_(const IfNode* op, const Expr& e) override;
  /*! \brief Used to visit the types inside of expressions.
   *
   * Can be overloaded to transform the types in arbitrary
   * ways, one way would be to define a sub-class of type
   * visitor for types which transform them appropriately.
   */
  virtual Type VisitType(const Type& t);

 private:
  /*! \brief Internal map used for memoization. */
  tvm::Map<Expr, Expr> memo_;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/logging.h
* \brief A wrapper around dmlc-core/logging.h which adds the ability
* to toggle logging via an environment variable.
*/
#ifndef TVM_RELAY_LOGGING_H_
#define TVM_RELAY_LOGGING_H_
#include <dmlc/logging.h>
#include <string>
#include <cstdlib>
#include <iostream>
namespace tvm {
namespace relay {
/*!
 * \brief Check whether Relay debug logging is enabled.
 *
 * Logging is toggled by setting the environment variable RELAY_LOG=1.
 * The variable is read on every call, so logging can be toggled at
 * runtime by changing the process environment.
 *
 * Declared `inline` (was `static`): a static function in a header gives
 * every translation unit its own internal copy and triggers
 * unused-function warnings; `inline` is the correct header idiom.
 *
 * \return true if RELAY_LOG is set to "1", false otherwise.
 */
inline bool logging_enabled() {
  if (auto var = std::getenv("RELAY_LOG")) {
    return std::string(var) == "1";
  }
  return false;
}

/*! \brief Log at the given severity only when RELAY_LOG=1 is set. */
#define RELAY_LOG(severity) LOG_IF(severity, logging_enabled())
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_LOGGING_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++.
*/
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#include <tvm/relay/environment.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*! \brief Infer the type of an expression with the provided environment.
*
 * The result of type checking is a new expression with unambiguous
 * type information filled in, as well as its checked type field
* populated with the result type.
*
* \param env The environment used for global settings and referencing
* global functions.
*
* \param e The expression to type check.
*
* \return A type checked expression with its checked_type field populated.
*/
Expr InferType(const Environment& env, const Expr& e);
Expr InferType(const Environment& env, const GlobalVar& v, const Function& e);
/*!
* \brief Check that types are well formed by applying "kinding rules".
*
* This pass ensures we do not do things that violate the design of the
* type system when writing down types.
*
* For example tensors are not allowed to contain functions in Relay.
*
* We check this by ensuring the `dtype` field of a Tensor always contains
* a data type such as `int`, `float`, `uint`.
*
* \param env The global environment.
* \param t The type to check.
 * \return true if the rules are satisfied otherwise false
*/
bool KindCheck(const Environment& env, const Type& t);
/*! \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `let x = 1 in x` is equal to `let y = 1 in y`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param e1 The left hand expression.
* \param e2 The right hand expression.
*
* \return true if equal, otherwise false
*/
bool AlphaEqual(const Expr& e1, const Expr& e2);
/*! \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `forall s, Tensor[f32, s]` is equal to
* `forall w, Tensor[f32, w]`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand type.
* \param t2 The right hand type.
*
* \return true if equal, otherwise false
*/
bool AlphaEqual(const Type& t1, const Type& t2);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_H_
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/type.h
* \brief Relay typed AST nodes.
*/
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_
#include <tvm/api_registry.h>
#include <tvm/ir.h>
#include <tvm/node.h>
#include <string>
#include "./base.h"
namespace tvm {
namespace relay {
/*! \brief Base type of the Relay type hiearchy. */
class TypeNode : public RelayNode {
public:
static constexpr const char* _type_key = "relay.Type";
TVM_DECLARE_BASE_NODE_INFO(TypeNode, Node);
};
/*!
 * \brief Type is the base type of the relay type hierarchy.
 *
 * Relay's type system contains following two key concepts:
 *
 * - TensorType: type of certain Tensor values in the expression.
 * - FunctionType: the type of the function.
 *
 * There are also advanced types to support generic(polymorphic types),
 * which can be ignored when first reading the code base.
 */
class Type : public NodeRef {
 public:
  /*! \brief default constructor, creates an undefined type reference */
  Type() {}
  /*! \brief construct from a shared node pointer */
  explicit Type(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}

  /*! \brief specify container node */
  using ContainerType = TypeNode;
};
/*!
 * \brief Base of all Tensor types
 * This container can hold TensorType or GenericTensorType.
 */
class BaseTensorTypeNode : public TypeNode {
 public:
  static constexpr const char* _type_key = "relay.BaseTensorType";
  TVM_DECLARE_BASE_NODE_INFO(BaseTensorTypeNode, TypeNode);
};

// Reference type wrapping BaseTensorTypeNode.
RELAY_DEFINE_NODE_REF(BaseTensorType, BaseTensorTypeNode, Type);
/*!
* \brief This is the most commonly used type in relay.
* TensorType have a fixed dimension, data type.
*
* The elements of shape can be either IntImm(constant integer),
* or any symbolic integer expression.
* The symbolic integer allows generic shape inference in certain cases.
* \sa TensorTypeNode The container class of TensorType.
*/
class TensorType;
/*! \brief TensorType container node */
class TensorTypeNode : public BaseTensorTypeNode {
 public:
  /*!
   * \brief The shape of the tensor,
   * represented by ShapeExpr(tvm::Expr).
   * Elements may be constant IntImm or symbolic expressions.
   */
  Array<ShapeExpr> shape;
  /*! \brief The content data type */
  DataType dtype;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
    v->Visit("span", &span);
  }

  /*!
   * \brief Construct a TensorType.
   * \param shape The tensor shape.
   * \param dtype The element data type.
   * \return The constructed TensorType reference.
   */
  TVM_DLL static TensorType make(Array<ShapeExpr> shape, DataType dtype);

  /*! \brief Construct an scalar containing elements of dtype. */
  TVM_DLL static TensorType Scalar(DataType dtype);

  static constexpr const char* _type_key = "relay.TensorType";
  TVM_DECLARE_NODE_TYPE_INFO(TensorTypeNode, BaseTensorTypeNode);
};

// Reference type wrapping TensorTypeNode.
RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type);
/*!
* \brief Type parameter in the function.
* This can be viewed as template parameter in c++ template function.
*
 * For example, in the following pseudo code,
* the TypeParam of f is TypeParam(kind=kShapeVar, var=n).
* This function can take in a Tensor with shape=(3, 3) and
* returns a Tensor with shape=(9,)
*
* \code
*
* template<i32 n>
* f(x : Tensor[i32, (n, n)]) -> Tensor[i32, (n * n)]
*
* \endcode
* \sa TypeParamNode The actual container class of TypeParam
*/
class TypeParam;
/*! \brief TypeParam container node */
class TypeParamNode : public TypeNode {
 public:
  /*! \brief possible kinds of TypeParam */
  enum Kind : int {
    /*! \brief template variable in shape expression */
    kShapeVar = 0,
    /*! \brief a shape tuple */
    kShape = 1,
    /*! \brief a base data type */
    kBaseType = 2,
    /*! \brief a full type */
    kType = 3
  };

  /*!
   * \brief The variable itself is only meaningful when
   * kind is ShapeVar, otherwise, we only use the name.
   */
  tvm::Var var;
  /*! \brief The kind of type parameter */
  Kind kind;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("var", &var);
    v->Visit("kind", &kind);
    v->Visit("span", &span);
  }

  /*!
   * \brief Construct a TypeParam.
   * \param name The name of the parameter.
   * \param kind The parameter kind.
   * \return The constructed TypeParam reference.
   */
  TVM_DLL static TypeParam make(std::string name, Kind kind);

  static constexpr const char* _type_key = "relay.TypeParam";
  TVM_DECLARE_NODE_TYPE_INFO(TypeParamNode, TypeNode);
};

// Reference type wrapping TypeParamNode.
RELAY_DEFINE_NODE_REF(TypeParam, TypeParamNode, Type);
/*!
 * \brief Potential Constraints in the type.
 * \note This is reserved for future use.
 */
class TypeConstraint;
/*! \brief TypeConstraint container node. */
class TypeConstraintNode : public TypeNode {
 public:
  static constexpr const char* _type_key = "relay.TypeConstraint";
  TVM_DECLARE_BASE_NODE_INFO(TypeConstraintNode, TypeNode);
};

// Reference type wrapping TypeConstraintNode.
RELAY_DEFINE_NODE_REF(TypeConstraint, TypeConstraintNode, Type);
class FuncType;
/*!
 * \brief Function type in Relay.
 *
 * Relay support polymorphic function type.
 * This can be roughly viewed as template function in C++.
 *
 * \sa TypeParam, TypeConstraint
 */
class FuncTypeNode : public TypeNode {
 public:
  /*! \brief The type of each argument. */
  tvm::Array<Type> arg_types;
  /*! \brief The type of return value. */
  Type ret_type;
  // The following fields are used in polymorphic(template) functions
  // For normal functions, the following two fields will be empty.
  /*! \brief The type parameters of the function */
  tvm::Array<TypeParam> type_params;
  /*!
   * \brief potential constraint the type need to obey
   * \note this field is reserved for further purposes.
   */
  tvm::Array<TypeConstraint> type_constraints;

  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("arg_types", &arg_types);
    v->Visit("ret_type", &ret_type);
    v->Visit("type_params", &type_params);
    v->Visit("type_constraints", &type_constraints);
    v->Visit("span", &span);
  }

  /*!
   * \brief Construct a FuncType.
   * \param arg_types The argument types.
   * \param ret_type The return type.
   * \param type_params The type parameters (empty for monomorphic functions).
   * \param type_constraints The type constraints (reserved).
   * \return The constructed FuncType reference.
   */
  TVM_DLL static FuncType make(tvm::Array<Type> arg_types, Type ret_type,
                               tvm::Array<TypeParam> type_params,
                               tvm::Array<TypeConstraint> type_constraints);

  static constexpr const char* _type_key = "relay.FuncType";
  TVM_DECLARE_NODE_TYPE_INFO(FuncTypeNode, TypeNode);
};

// Reference type wrapping FuncTypeNode.
RELAY_DEFINE_NODE_REF(FuncType, FuncTypeNode, Type);
/*! \brief Environment-registered function computing an output type
 * from a list of input types. */
using TypeRelationFn =
    TypedEnvFunc<Array<Type>(const Array<Type>&, int)>;

/*!
 * \brief Opaque type relation, is an input-output relation on types.
 */
class TypeRelation;
/*!
 * \brief TypeRelation container.
 * \note This node is not directly serializable.
 * The type function need to be lookedup in the environment.
 */
class TypeRelationNode : public TypeConstraintNode {
 public:
  /*! \brief The name of the function */
  std::string name;
  /*!
   * \brief The function on input and output variables which
   * this is not directly serializable,
   * need to be looked-up in the environment.
   */
  TypeRelationFn func_;
  /*! \brief The type arguments to the type function. */
  tvm::Array<Type> args;

  // NOTE(review): only `name` is visited here; `func_` is documented as
  // non-serializable, but `args` is also skipped — confirm this is intended.
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("name", &name);
  }

  /*!
   * \brief Construct a TypeRelation.
   * \param name The relation name.
   * \param func_ The relation function.
   * \param args The type arguments.
   * \return The constructed TypeRelation reference.
   */
  TVM_DLL static TypeRelation make(std::string name, TypeRelationFn func_, Array<Type> args);

  static constexpr const char* _type_key = "relay.TypeRelation";
  TVM_DECLARE_NODE_TYPE_INFO(TypeRelationNode, TypeConstraintNode);
};

// Reference type wrapping TypeRelationNode.
RELAY_DEFINE_NODE_REF(TypeRelation, TypeRelationNode, TypeConstraint);
/*!
 * \brief The type of tuple values.
 */
class TupleType;
/*!
 * \brief TupleType container.
 */
class TupleTypeNode : public TypeNode {
 public:
  /*! \brief The type of each field in the tuple. */
  tvm::Array<Type> fields;

  /*! \brief default constructor */
  TupleTypeNode() {}

  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("fields", &fields); }

  /*!
   * \brief Construct a TupleType from field types.
   * \param fields The field types.
   * \return The constructed TupleType reference.
   */
  TVM_DLL static TupleType make(tvm::Array<Type> fields);

  // NOTE(review): the key is "relay.TypeTuple" while the class is
  // TupleTypeNode — looks transposed, but changing it would break any
  // registration keyed on this string; confirm before renaming.
  static constexpr const char* _type_key = "relay.TypeTuple";
  TVM_DECLARE_NODE_TYPE_INFO(TupleTypeNode, TypeNode);
};

// Reference type wrapping TupleTypeNode.
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
class GenericTensorType;
// stores a DataType.
class GenericDataType;
// stores a DataType.
class GenericShape;
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TYPE_H_
# pylint: disable=wildcard-import
"""The Relay IR namespace containing the IR definition and compiler."""
from . import base
from . import ty
from . import expr
from . import env
from . import ir_pass
from . import ir_builder
# Operators
from .op import Op
from .op.tensor import *
# Re-export commonly used classes at the `relay` package level.

# Span
Span = base.Span

# Type
Type = ty.Type
TensorType = ty.TensorType
Kind = ty.Kind
TypeParam = ty.TypeParam
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType

# Expr
Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
GlobalVar = expr.GlobalVar
Param = expr.Param
Function = expr.Function
Call = expr.Call
Let = expr.Let
If = expr.If
# (removed redundant `Var = Var` self-assignment; Var is bound above)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface to the Environment exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._env", __name__)
from typing import Union, Tuple, Dict, List
from relay.ir import GlobalId, OperatorId, Item, NodeBase, Span, FileId
from relay.ir import ShapeExtension, Operator, Defn
class Environment(NodeBase): ...
\ No newline at end of file
"""FFI exposing the Relay type inference and checking."""
from tvm._ffi.function import _init_api
_init_api("relay._ir_pass", __name__)
from .env import Environment
from . import ir
def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ...
def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ...
def _get_checked_type(expr: ir.Expr) -> ir.Type: ...
"""
The constructors for all Relay AST nodes exposed from C++.
This module includes MyPy type signatures for all of the
exposed modules.
"""
from .._ffi.function import _init_api
_init_api("relay._make", __name__)
# pylint: disable=no-else-return, unidiomatic-typecheck
"""The base node types for the Relay language."""
from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
NodeBase = NodeBase
def register_relay_node(type_key=None):
    """Register a Relay node type with the TVM node system.

    Parameters
    ----------
    type_key : str or cls
        The type key of the node, or the node class itself
        (in which case the key is derived as ``"relay." + class name``).
    """
    # Used as a bare decorator: the argument is the class itself.
    if isinstance(type_key, str):
        return _register_tvm_node(type_key)
    return _register_tvm_node("relay." + type_key.__name__)(type_key)
@register_relay_node
class Span(NodeBase):
    # Source location (source name, line, column) attached to Relay nodes.
    def __init__(self, source, lineno, col_offset):
        # Constructed through the C++ FFI (_make.Span).
        self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, wildcard-import
"""A global environment storing everything needed to interpret or compile a Relay program."""
from .base import register_relay_node, NodeBase
from . import _make
from . import _env
@register_relay_node
class Environment(NodeBase):
    """The global Relay environment containing functions,
    options and more.
    """

    def __init__(self, funcs):
        """Construct an environment.

        Parameters
        ----------
        funcs: list of relay.Function

        Returns
        -------
        env: A new environment containing :py:class:`~relay.env.Environment`.
        """
        self.__init_handle_by_constructor__(_make.Environment, funcs)

    def add(self, var, func):
        """Add a function to the environment.

        Parameters
        ----------
        var: GlobalVar
            The global variable which names the function.

        func: Function
            The function.
        """
        # A bare string name is first resolved to its GlobalVar.
        if isinstance(var, str):
            var = _env.Environment_GetGlobalVar(self, var)

        _env.Environment_Add(self, var, func)

    def merge(self, other):
        """Merge two environments.

        Parameters
        ----------
        other: Environment
            The environment to merge into the current Environment.
        """
        return _env.Environment_Merge(self, other)

    def global_var(self, name):
        """Get a global variable by name.

        Parameters
        ----------
        name: str
            The name of the global variable.

        Returns
        -------
        global_var: GlobalVar
            The global variable mapped to :code:`name`.
        """
        return _env.Environment_GetGlobalVar(self, name)

    def __getitem__(self, var):
        """Lookup a global function by name or by variable.

        Parameters
        ----------
        var: str or GlobalVar
            The name or global variable.

        Returns
        -------
        func: Function
            The function referenced by :code:`var`.
        """
        if isinstance(var, str):
            return _env.Environment_Lookup_str(self, var)
        else:
            return _env.Environment_Lookup(self, var)
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
from __future__ import absolute_import
from .base import NodeBase, register_relay_node
from ._ir_pass import _get_checked_type
from . import _make
from .. import convert
class Expr(NodeBase):
    """The base type for all Relay expressions."""

    def checked_type(self):
        """Return the type inferred for this expression by the type checker."""
        return _get_checked_type(self)

    def __call__(self, *args):
        """Build a Call node applying this expression to ``args``.

        Param arguments are unwrapped to their underlying variable so a
        function parameter can be passed directly as a call argument.
        """
        converted_args = []
        for arg in args:
            if isinstance(arg, Param):
                converted_args.append(arg.var)
            else:
                converted_args.append(arg)

        # Bug fix: previously the raw ``args`` tuple was passed here,
        # silently discarding the Param -> Var conversion above.
        return Call(self, converted_args, None, None)
@register_relay_node
class Constant(Expr):
    """A constant tensor in Relay, see tvm/relay/expr.h for more details.
    """

    def __init__(self, data):
        # data: the constant tensor value (constructed via the C++ FFI)
        self.__init_handle_by_constructor__(_make.Constant, data)
@register_relay_node
class Tuple(Expr):
    """A heterogeneous sequence of values.
    see tvm/relay/expr.h for more details.
    """

    def __init__(self, fields):
        # fields: the expressions making up the tuple
        self.__init_handle_by_constructor__(_make.Tuple, fields)
@register_relay_node
class Var(Expr):
    """A local variable in Relay."""

    def __init__(self, name_hint):
        # name_hint: the variable's name (a hint, not necessarily unique)
        self.__init_handle_by_constructor__(_make.Var, name_hint)


@register_relay_node
class GlobalVar(Expr):
    """A global variable in Relay."""

    def __init__(self, name_hint):
        # name_hint: the name the variable is registered under
        self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)
@register_relay_node
class Param(Expr):
    """A function parameter (a variable together with its type),
    see tvm/relay/expr.h for more details.
    """

    def __init__(self, var, ty):
        # var: the parameter variable; ty: its declared type
        self.__init_handle_by_constructor__(_make.Param, var, ty)
@register_relay_node
class Function(Expr):
    """A function in Relay, see tvm/relay/expr.h for more details."""

    def __init__(self,
                 params,
                 ret_type,
                 body,
                 type_params=None
                ):
        # type_params defaults to an empty TVM array rather than a
        # mutable Python default argument.
        if type_params is None:
            type_params = convert([])

        self.__init_handle_by_constructor__(
            _make.Function, params, ret_type, body, type_params)
@register_relay_node
class Call(Expr):
    """A function call in Relay, see tvm/relay/expr.h for more details."""

    def __init__(self, op, args, attrs, ty_args=None):
        """Construct a call of ``op`` applied to ``args``.

        Parameters
        ----------
        op: Expr
            The callee (operator, function, or variable).
        args: list of Expr
            The positional arguments.
        attrs: Attrs
            Operator attributes, may be None.
        ty_args: list of Type, optional
            Explicit type arguments for generic callees.
        """
        type_arguments = ty_args or []
        self.__init_handle_by_constructor__(
            _make.Call, op, args, attrs, type_arguments)


@register_relay_node
class Let(Expr):
    """A variable binding in Relay, see tvm/relay/expr.h for more details."""

    def __init__(self, var, value, body, value_type):
        """Bind ``var`` to ``value`` (annotated with ``value_type``)
        within ``body``.
        """
        self.__init_handle_by_constructor__(
            _make.Let, var, value, body, value_type)


@register_relay_node
class If(Expr):
    """A conditional expression in Relay, see tvm/relay/expr.h for more details."""

    def __init__(self, cond, true_value, false_value):
        """Construct a conditional evaluating to ``true_value`` or
        ``false_value`` depending on ``cond``.
        """
        self.__init_handle_by_constructor__(
            _make.If, cond, true_value, false_value)
from typing import List
import tvm
from .base import Span, NodeBase
from .ty import Type, TypeParam
from ._ir_pass import _get_checked_type
class Expr(NodeBase):
    """Type stub for the base class of all Relay expressions."""
    def checked_type(self):
        ...
    def __call__(self, *args):
        ...


class Constant(Expr):
    """Type stub for a constant tensor expression."""
    data = ...  # type: tvm.nd.NDArray
    def __init__(self, data):
        # type: (tvm.nd.NDArray) -> None
        ...
class Tuple(Expr):
    """Type stub for a heterogeneous sequence of expressions."""
    # Bug fix: the attribute placeholder was written as ``..`` (two dots),
    # which is a SyntaxError; the Ellipsis literal is ``...``.
    fields = ...  # type: List[Expr]
    def __init__(self, fields):
        # type: (List[Expr]) -> None
        ...
class Var(Expr):
    """A local variable in Relay."""
    name_hint = ...  # type: str
    def __init__(self, name_hint):
        # type: (str) -> None
        ...


class GlobalVar(Expr):
    """Type stub for a global variable."""
    name_hint = ...  # type: str
    def __init__(self, name_hint):
        # type: (str) -> None
        ...


class Param(Expr):
    """Type stub for a function parameter (variable plus type)."""
    var = ...  # type: Var
    type = ...  # type: Type
    def __init__(self, var, ty):
        # type: (Var, Type) -> None
        ...


class Function(Expr):
    """A function in Relay, see tvm/relay/expr.h for more details."""
    type_params = ...  # type: List[TypeParam]
    params = ...  # type: List[Param]
    ret_type = ...  # type: Type
    body = ...  # type: Expr
    def __init__(self,
                 params,  # type: List[Param]
                 ret_type,  # type: Type
                 body,  # type: Expr
                 type_params=None,  # type: List[TypeParam]
                 ):
        # type: (...) -> None
        ...
# Cleanup: the stray @register_relay_node decorators are removed below —
# the name is not imported in this stub module, and registration belongs
# to the implementation module, not the type stubs. The Call stub also
# carried a real constructor body referencing the undefined ``_make``;
# stubs only declare signatures.
class Call(Expr):
    """A function call in Relay, see tvm/relay/expr.h for more details."""
    op = ...  # type: Expr
    args = ...  # type: List[Expr]
    # todo(@jroesch): add attrs
    def __init__(self, op, args, attrs, ty_args=None):
        # type: (Expr, List[Expr], Optional[List[Type]]) -> None
        ...


class Let(Expr):
    """A variable binding in Relay, see tvm/relay/expr.h for more details."""
    var = ...  # type: Var
    value = ...  # type: Expr
    body = ...  # type: Expr
    value_type = ...  # type: Type
    def __init__(self, var, value, body, value_type):
        # type: (Var, Expr, Expr, Type) -> None
        ...


class If(Expr):
    """A conditional expression in Relay, see tvm/relay/expr.h for more details."""
    cond = ...  # type: Expr
    true_value = ...  # type: Expr
    false_value = ...  # type: Expr
    span = ...  # type: Span
    def __init__(self, cond, true_value, false_value):
        # type: (Expr, Expr, Expr) -> None
        ...
# pylint: disable=no-else-return,
# pylint: disable=unidiomatic-typecheck
"""The set of passes for Relay.
Exposes an interface for configuring the passes and scripting
them in Python.
"""
from . import _ir_pass

# Expose the type-checking entry point; should be renamed to infer_type.
# pylint: disable=invalid-name
check_expr = _ir_pass.check_expr
#pylint: disable=wildcard-import
"""Relay core operators."""
# operator defs
from .op import get, register, Op
# Operators
from .tensor import *
# operator registry
from . import _tensor
from ..expr import Expr
from ..base import register_relay_node
"""Constructor APIs"""
from ..._ffi.function import _init_api
_init_api("relay.op._make", __name__)
#pylint: disable=invalid-name
"""Backend compiler related feature registration"""
"""The base node types for the Relay language."""
from ..._ffi.function import _init_api
from ..base import register_relay_node
from ..expr import Expr
@register_relay_node
class Op(Expr):
    """A Relay operator definition."""

    def __init__(self):
        # Operators are interned by the C++ registry; constructing one
        # directly from Python is always a programming error.
        raise RuntimeError("Cannot create op, use get instead")

    def get_attr(self, attr_name):
        """Get additional attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name.

        Returns
        -------
        value : object
            The attribute value
        """
        return _OpGetAttr(self, attr_name)
def get(op_name):
    """Get the Op for a given name

    Parameters
    ----------
    op_name : str
        The operator name

    Returns
    -------
    op : Op
        The op of the corresponding name
    """
    # Delegates to the C++ registry, which raises if the name is unknown.
    return _GetOp(op_name)
def register(op_name, attr_key, value=None, level=10):
    """Register an operator property of an operator.

    Parameters
    ----------
    op_name : str
        The name of operator

    attr_key : str
        The attribute name.

    value : object, optional
        The value to set

    level : int, optional
        The priority level

    Returns
    -------
    fregister : function
        Register function if value is not specified.
    """
    def _register(v):
        """internal register function"""
        _Register(op_name, attr_key, v, level)
        return v
    # Bug fix: test against None rather than truthiness so that falsy
    # attribute values (0, False, "") are registered directly instead of
    # being mistaken for decorator usage.
    return _register(value) if value is not None else _register

_init_api("relay.op", __name__)
"""Basic tensor operations."""
from __future__ import absolute_import as _abs
from . import _make
from ..expr import Tuple
# We create a wrapper function for each operator in the
# python side to call into the positional _make.OpName function.
#
# We make this decision so that we can:
# - Have declare python docstring for each function
# - Enable keyword arguments easily
# - Not put too much burden on FFI to support complicated features
# like default value and keyword arguments
def log(data):
    """Take log of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.log(data)


def exp(data):
    """Take exp of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.exp(data)


def sqrt(data):
    """Take sqrt of data.

    Parameters
    ----------
    data : relay.Expr
        The input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.sqrt(data)
def add(lhs, rhs):
    """Elementwise addition.

    Parameters
    ----------
    lhs : relay.Expr
        The left hand side input data
    rhs : relay.Expr
        The right hand side input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.add(lhs, rhs)
def subtract(lhs, rhs):
    """Elementwise subtraction.

    Parameters
    ----------
    lhs : relay.Expr
        The left hand side input data
    rhs : relay.Expr
        The right hand side input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    # Bug fix: this previously dispatched to _make.add, silently computing
    # lhs + rhs instead of lhs - rhs.
    return _make.subtract(lhs, rhs)
def equal(lhs, rhs):
    """Elementwise equality comparison of ``lhs`` and ``rhs``."""
    return _make.equal(lhs, rhs)


def concat(*args):
    """Concatenate the input tensors along the zero axis.

    Parameters
    ----------
    args: list of Tensor

    Returns
    -------
    tensor: The concatenated tensor.
    """
    # The FFI entry point takes a single Tuple of operands.
    return _make.concat(Tuple(list(args)))
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import NodeBase, register_relay_node
from . import _make
class Type(NodeBase):
    """The base type for all Relay types."""

    def __eq__(self, other):
        """Structural (alpha) equivalence between two Relay types."""
        return bool(_make._type_alpha_eq(self, other))

    def __ne__(self, other):
        return not (self == other)

    def same_as(self, other):
        """Compares two Relay types by referential equality."""
        return super().__eq__(other)
@register_relay_node
class TensorType(Type):
    """A concrete TensorType in Relay, see tvm/relay/type.h for more details.

    This is the type assigned to tensors with a known dtype and shape,
    for example a `float32` tensor of shape `(5, 5)`.
    """

    def __init__(self, shape, dtype):
        """Construct a tensor type.

        Parameters
        ----------
        shape: list of tvm.Expr
            The (possibly symbolic) shape.
        dtype: str
            The content data type.
        """
        self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)


class Kind(IntEnum):
    """The kind of a type parameter: a variable shape, base type,
    type, or dimension.

    This controls what a type parameter may be instantiated with; for
    example one of kind BaseType can only become `float32`, `int32`,
    and so on.
    """
    ShapeVar = 0
    Shape = 1
    BaseType = 2
    Type = 3
@register_relay_node
class TypeParam(Type):
    """A type parameter used for generic types in Relay,
    see tvm/relay/type.h for more details.

    A type parameter is a placeholder filled in later, allowing users
    to write functions that are generic over types.
    """

    def __init__(self, var, kind):
        """Construct a TypeParam.

        Parameters
        ----------
        var: tvm.expr.Var
            The tvm.Var which backs the type parameter.
        kind: Kind
            The kind of the type parameter.

        Returns
        -------
        type_param: TypeParam
            The type parameter.
        """
        self.__init_handle_by_constructor__(_make.TypeParam, var, kind)


@register_relay_node
class TypeConstraint(Type):
    """Abstract class representing a type constraint."""
    pass
@register_relay_node
class FuncType(Type):
    """A function type in Relay, see tvm/relay/type.h for more details.

    This is the type assigned to functions in Relay. It consists of a
    list of type parameters (enabling generic functions), a set of type
    constraints (currently unused), a sequence of argument types, and a
    return type.

    Informally written as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`
    """

    def __init__(self, arg_types, ret_type, type_params, type_constraints):
        """Construct a function type.

        Parameters
        ----------
        arg_types: list of Type
        ret_type: Type
        type_params: list of TypeParam
        type_constraints: list of TypeConstraint

        Returns
        -------
        func_type: FuncType
            The function type.
        """
        self.__init_handle_by_constructor__(
            _make.FuncType, arg_types, ret_type, type_params, type_constraints)


@register_relay_node
class IncompleteType(Type):
    """An incomplete type: a hole to be filled by type inference."""

    def __init__(self, kind):
        self.__init_handle_by_constructor__(_make.IncompleteType, kind)
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The type nodes of the Relay language."""
from enum import IntEnum
from .base import NodeBase, register_relay_node
from . import _make
class Type(NodeBase):
    """The base type for all Relay types."""
    def __eq__(self, other):
        """Compare two Relay types for structural equivalence using
        alpha equivalence.
        """
        return bool(_make._type_alpha_eq(self, other))
    def __ne__(self, other):
        return not self.__eq__(other)
    def same_as(self, other):
        """Compares two Relay types by referential equality."""
        # NOTE(review): defining __eq__ without __hash__ leaves hashing to
        # NodeBase — confirm NodeBase defines __hash__, otherwise instances
        # are unhashable on Python 3.
        return super().__eq__(other)


@register_relay_node
class TensorType(Type):
    """A concrete TensorType in Relay, see tvm/relay/type.h for more details.
    This is the type assigned to tensors with a known dtype and shape. For
    example a tensor of `float32` and `(5, 5)`.
    """
    def __init__(self, shape, dtype):
        """Construct a tensor type.

        Parameters
        ----------
        shape: list of tvm.Expr
        dtype: str

        Returns
        -------
        tensor_type: The TensorType
        """
        self.__init_handle_by_constructor__(_make.TensorType, shape, dtype)


class Kind(IntEnum):
    """The kind of a type parameter, represents a variable shape,
    base type, type, or dimension.
    This controls what a type parameter is allowed to be instantiated
    with. For example one of kind BaseType can only be `float32`, `int32`,
    and so on.
    """
    # Variable over shapes.
    ShapeVar = 0
    # A concrete shape.
    Shape = 1
    # A scalar base type such as float32.
    BaseType = 2
    # A full type.
    Type = 3
@register_relay_node
class TypeParam(Type):
    """A type parameter used for generic types in Relay,
    see tvm/relay/type.h for more details.
    A type parameter represents a type placeholder which will
    be filled in later on. This allows the user to write
    functions which are generic over types.
    """
    def __init__(self, var, kind):
        """Construct a TypeParam.

        Parameters
        ----------
        var: tvm.expr.Var
            The tvm.Var which backs the type parameter.
        kind: Kind
            The kind of the type parameter.

        Returns
        -------
        type_param: TypeParam
            The type parameter.
        """
        self.__init_handle_by_constructor__(_make.TypeParam, var, kind)


@register_relay_node
class TypeConstraint(Type):
    """Abstract class representing a type constraint."""
    pass


@register_relay_node
class FuncType(Type):
    """A function type in Relay, see tvm/relay/type.h for more details.
    This is the type assigned to functions in Relay. They consist of
    a list of type parameters which enable the definition of generic
    functions, a set of type constraints which we omit for the time
    being, a sequence of argument types, and a return type.
    We informally write them as:
    `forall (type_params), (arg_types) -> ret_type where type_constraints`
    """
    def __init__(self,
                 arg_types,
                 ret_type,
                 type_params,
                 type_constraints,
                 ):
        """Construct a function type.

        Parameters
        ----------
        arg_types: list of Type
        ret_type: Type
        type_params: list of TypeParam
        type_constraints: list of TypeConstraint

        Returns
        -------
        func_type: FuncType
            The function type.
        """
        self.__init_handle_by_constructor__(
            _make.FuncType, arg_types, ret_type, type_params, type_constraints)


@register_relay_node
class IncompleteType(Type):
    """An incomplete type."""
    def __init__(self, kind):
        # kind: Kind of the hole to be solved by type inference.
        self.__init_handle_by_constructor__(_make.IncompleteType, kind)
......@@ -6,8 +6,10 @@ from . import _api_internal
from . import make as _make
from . import expr as _expr
class TensorSlice(NodeGeneric, _expr.ExprOp):
"""Auxiliary data structure for enable slicing syntax from tensor."""
def __init__(self, tensor, indices):
if not isinstance(indices, tuple):
indices = (indices,)
......@@ -31,9 +33,11 @@ class TensorSlice(NodeGeneric, _expr.ExprOp):
itervar_cls = None
@register_node
class Tensor(NodeBase, _expr.ExprOp):
"""Tensor object, to construct, see function.Tensor"""
def __call__(self, *indices):
ndim = self.ndim
if len(indices) != ndim:
......@@ -104,6 +108,7 @@ class Tensor(NodeBase, _expr.ExprOp):
class Operation(NodeBase):
"""Represent an operation that generate a tensor"""
def output(self, index):
"""Get the index-th output of the operation
......
/*!
* Copyright (c) 2018 by Contributors
* \file base.cc
* \brief The core base types for Relay.
*/
#include <tvm/api_registry.h>
#include <tvm/relay/base.h>
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
// Allocate a SourceNameNode wrapping the given file/source name.
SourceName SourceNameNode::make(std::string name) {
  auto n = std::make_shared<SourceNameNode>();
  n->name = std::move(name);
  return SourceName(n);
}
std::shared_ptr<SourceNameNode> CreateSourceName(const std::string& name) {
SourceName sn = SourceName::Get(name);
CHECK(!sn.defined()) << "Cannot find source name \'" << name << '\'';
std::shared_ptr<Node> node = sn.node_;
return std::dynamic_pointer_cast<SourceNameNode>(node);
}
// Return the interned SourceName for `name`, creating it on first use.
// NOTE(review): the intern table is a function-local static with no lock —
// assumes single-threaded use; confirm before calling concurrently.
const SourceName& SourceName::Get(const std::string& name) {
  static std::unordered_map<std::string, SourceName> source_map;
  auto it = source_map.find(name);
  if (it != source_map.end()) {
    return it->second;
  }
  auto inserted = source_map.insert({name, SourceNameNode::make(name)});
  return inserted.first->second;
}
TVM_REGISTER_API("relay._make.SourceName")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *ret) {
*ret = SourceNameNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode *node, tvm::IRPrinter *p) {
p->stream << "SourceNameNode(" << node->name << ", " << node << ")";
});
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(CreateSourceName)
.set_global_key([](const Node* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
// Build a Span pointing at (lineno, col_offset) within `source`.
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
  auto n = std::make_shared<SpanNode>();
  n->source = std::move(source);
  n->lineno = lineno;
  n->col_offset = col_offset;
  return Span(n);
}

// FFI constructor exposed as relay._make.Span(source, lineno, col_offset).
TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = SpanNode::make(args[0], args[1], args[2]);
  });

// Debug printer for SpanNode.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
    p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
              << node->col_offset << ")";
  });
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file environment.cc
* \brief The global environment in Relay.
*/
#include <tvm/relay/environment.h>
#include <tvm/relay/pass.h>
#include <sstream>
#include "./../pass/resolve.h"
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using namespace runtime;
// Construct an Environment seeded with the given global function map.
Environment EnvironmentNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
  auto n = std::make_shared<EnvironmentNode>();
  n->functions = std::move(global_funcs);
  return Environment(n);
}
// Return the interned GlobalVar for `str`, creating and caching it on
// first use so repeated lookups yield the same handle.
GlobalVar EnvironmentNode::GetGlobalVar(const std::string &str) {
  auto it = global_map_.find(str);
  if (it != global_map_.end()) {
    return (*it).second;
  }
  GlobalVar id = GlobalVarNode::make(str);
  global_map_.Set(str, id);
  return id;
}
/*! \brief Add a new item to the global environment
* \note if the update flag is not set adding a duplicate
* definition will trigger an exception, otherwise we will
* update the definition if and only if it is type compatible.
*/
void EnvironmentNode::Add(const GlobalVar &var, const Function &func,
bool update) {
// Type check the item before we add it to the environment.
auto env = GetRef<Environment>(this);
Expr checked_expr = InferType(env, var, func);
if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
auto checked_func = GetRef<Function>(func_node);
auto type = checked_func->checked_type();
CHECK(IsFullyResolved(type));
if (functions.find(var) != functions.end()) {
if (!update) {
throw dmlc::Error("already have definition for XXXX.");
}
auto old_type = functions[var].as<FunctionNode>()->checked_type();
if (!AlphaEqual(type, old_type)) {
throw dmlc::Error(
"Environment#update changes type, not possible in this mode.");
}
this->functions.Set(var, checked_func);
} else {
this->functions.Set(var, checked_func);
}
} else {
throw Error("internal error: unknown item type, unreachable code");
}
}
// Insert-or-replace `var` with `func`, allowing type-compatible updates.
void EnvironmentNode::Update(const GlobalVar &var, const Function &func) {
  this->Add(var, func, /*update=*/true);
}

// Erase the definition bound to `var`, if any.
void EnvironmentNode::Remove(const GlobalVar &var) {
  auto functions_node = this->functions.CopyOnWrite();
  functions_node->data.erase(var.node_);
}

// Look up the function bound to `var`; throws if it is undefined.
Function EnvironmentNode::Lookup(const GlobalVar &var) {
  auto func = functions.find(var);
  if (func != functions.end()) {
    return (*func).second;
  } else {
    throw Error(std::string("there is no definition of ") + var->name_hint);
  }
}

// Name-based lookup: interns the GlobalVar, then resolves it.
Function EnvironmentNode::Lookup(const std::string &str) {
  GlobalVar id = this->GetGlobalVar(str);
  return this->Lookup(id);
}

// Copy every definition from `env` into this environment, overwriting
// bindings that share a GlobalVar.
void EnvironmentNode::Merge(const Environment &env) {
  // Fix: iterate by const reference — the loop previously copied each
  // (GlobalVar, Function) pair per iteration.
  for (const auto &pair : env->functions) {
    this->functions.Set(pair.first, pair.second);
  }
}
TVM_REGISTER_API("relay._make.Environment")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = EnvironmentNode::make(args[0]);
});
TVM_REGISTER_API("relay._env.Environment_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Add(args[1], args[2], false);
});
TVM_REGISTER_API("relay._env.Environment_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
*ret = env->GetGlobalVar(args[1]);
});
TVM_REGISTER_API("relay._env.Environment_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
GlobalVar var = args[1];
*ret = env->Lookup(var);
});
TVM_REGISTER_API("relay._env.Environment_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
std::string var_name = args[1];
auto var = env->GetGlobalVar(var_name);
*ret = env->Lookup(var);
});
TVM_REGISTER_API("relay._env.Environment_Merge")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Environment env = args[0];
env->Merge(args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<EnvironmentNode>([](const EnvironmentNode *node,
tvm::IRPrinter *p) {
p->stream << "EnvironmentNode( " << node->functions << ")";
});
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/ir/expr.cc
* \brief The expression AST nodes of Relay.
*/
#include <tvm/ir_functor.h>
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
// Wrap an NDArray as a Relay constant expression.
Constant ConstantNode::make(runtime::NDArray data) {
  auto n = std::make_shared<ConstantNode>();
  n->data = std::move(data);
  return Constant(n);
}

TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = ConstantNode::make(args[0]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node,
                               tvm::IRPrinter* p) {
    // TODO: print the tensor contents/shape instead of a placeholder.
    p->stream << "ConstantNode(TODO)";
  });

// Derive the TensorType of a constant: rank and extents from the backing
// NDArray, dtype converted from its TVMType.
TensorType ConstantNode::tensor_type() const {
  auto dtype = TVMType2Type(data->dtype);
  Array<tvm::Expr> dims;
  for (int axis = 0; axis < data->ndim; ++axis) {
    dims.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), data->shape[axis]));
  }
  return TensorTypeNode::make(dims, dtype);
}
// Build a tuple expression from an ordered field list.
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
  auto n = std::make_shared<TupleNode>();
  n->fields = std::move(fields);
  return Tuple(n);
}

TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = TupleNode::make(args[0]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
    p->stream << "TupleNode(" << node->fields << ")";
  });

// Build a local variable; name_hint is advisory and not deduplicated.
Var VarNode::make(std::string name_hint) {
  auto n = std::make_shared<VarNode>();
  n->name_hint = std::move(name_hint);
  return Var(n);
}

TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = VarNode::make(args[0]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode* node,
                          tvm::IRPrinter* p) {
    p->stream << "VarNode(" << node->name_hint << ")";
  });

// Build a global variable referencing a definition in the Environment.
GlobalVar GlobalVarNode::make(std::string name_hint) {
  auto n = std::make_shared<GlobalVarNode>();
  n->name_hint = std::move(name_hint);
  return GlobalVar(n);
}

TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = GlobalVarNode::make(args[0]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node,
                                tvm::IRPrinter* p) {
    p->stream << "GlobalVarNode(" << node->name_hint << ")";
  });
// Build a function parameter: a variable together with its type annotation.
Param ParamNode::make(Var var, Type type) {
  auto n = std::make_shared<ParamNode>();
  n->var = std::move(var);
  n->type = std::move(type);
  return Param(n);
}

TVM_REGISTER_API("relay._make.Param")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = ParamNode::make(args[0], args[1]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ParamNode>([](const ParamNode* node, tvm::IRPrinter* p) {
    p->stream << "ParamNode(" << node->var << ", " << node->type << ")";
  });
// Build a function from its parameters, return type, body, and
// (possibly empty) type parameters.
Function FunctionNode::make(tvm::Array<Param> params, Type ret_type, Expr body,
                            tvm::Array<TypeParam> type_params) {
  std::shared_ptr<FunctionNode> n = std::make_shared<FunctionNode>();
  n->params = std::move(params);
  n->ret_type = std::move(ret_type);
  n->body = std::move(body);
  n->type_params = std::move(type_params);
  return Function(n);
}

// Derive the FuncType of this function from its parameter types, return
// type, and type parameters; no type constraints are attached.
Type FunctionNode::fn_type() const {
  Array<Type> param_types;
  // Fix: iterate by const reference — Param is a ref-counted handle and
  // the by-value loop copied (refcount-bumped) it every iteration.
  for (const auto& param : this->params) {
    param_types.push_back(param->type);
  }
  return FuncTypeNode::make(param_types, this->ret_type, this->type_params, {});
}

TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = FunctionNode::make(args[0], args[1], args[2], args[3]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode *node,
                               tvm::IRPrinter *p) {
    p->stream << "FunctionNode(" << node->params << ", " << node->ret_type
              << ", " << node->body << ", " << node->type_params << ")";
  });
// Build a call of `op` applied to `args`, with operator attributes and
// explicit type arguments for generic callees.
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
                    Array<Type> type_args) {
  auto n = std::make_shared<CallNode>();
  n->op = std::move(op);
  n->args = std::move(args);
  n->attrs = std::move(attrs);
  n->type_args = std::move(type_args);
  return Call(n);
}

TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = CallNode::make(args[0], args[1], args[2], args[3]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
    p->stream << "CallNode(" << node->op << ", " << node->args << ", "
              << node->attrs << ", " << node->type_args << ")";
  });

// Build a let-binding: `var` (annotated with `value_type`) is bound to
// `value` within `body`.
Let LetNode::make(Var var, Expr value, Expr body, Type value_type) {
  auto n = std::make_shared<LetNode>();
  n->var = std::move(var);
  n->value = std::move(value);
  n->body = std::move(body);
  n->value_type = std::move(value_type);
  return Let(n);
}

TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = LetNode::make(args[0], args[1], args[2], args[3]);
  });

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
    p->stream << "LetNode(" << node->var << ", " << node->value
              << ", " << node->body << ", " << node->value_type << ")";
  });
// Build a conditional expression selecting between two branches.
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
  std::shared_ptr<IfNode> n = std::make_shared<IfNode>();
  n->cond = std::move(cond);
  n->true_branch = std::move(true_branch);
  n->false_branch = std::move(false_branch);
  return If(n);
}

TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = IfNode::make(args[0], args[1], args[2]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode *node, tvm::IRPrinter *p) {
    // Bug fix: a ", " separator was missing between the true and false
    // branches, fusing them together in the printed output.
    p->stream << "IfNode(" << node->cond << ", " << node->true_branch
              << ", " << node->false_branch << ")";
  });
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/expr_mutator.cc
* \brief A wrapper around ExprFunctor which functionally updates the AST.
*
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/
#include <tvm/relay/expr_functor.h>
namespace tvm {
namespace relay {
// Memoized dispatch: each distinct expression is rewritten at most once
// and the result cached, amortizing the cost of functional updates.
Expr ExprMutator::Mutate(const Expr& expr) {
  auto it = this->memo_.find(expr);
  if (it != this->memo_.end()) {
    return (*it).second;
  }
  Expr new_expr = this->ExprMutator::VisitExpr(expr, expr);
  this->memo_.Set(expr, new_expr);
  return new_expr;
}
// Leaf expressions have no sub-expressions to rewrite; return them as-is.
Expr ExprMutator::VisitExpr_(const VarNode* op, const Expr& expr) {
  return expr;
}

Expr ExprMutator::VisitExpr_(const ConstantNode* op, const Expr& expr) {
  return expr;
}

Expr ExprMutator::VisitExpr_(const GlobalVarNode* op, const Expr& expr) {
  return expr;
}

Expr ExprMutator::VisitExpr_(const OpNode* op, const Expr& expr) {
  return expr;
}
// Rewrite each tuple field; reuse the original node when nothing changed.
Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
  tvm::Array<Expr> new_fields;
  bool unchanged = true;
  for (const auto& field : op->fields) {
    Expr rewritten = this->Mutate(field);
    new_fields.push_back(rewritten);
    unchanged &= rewritten.same_as(field);
  }
  if (unchanged) {
    return e;
  }
  return TupleNode::make(new_fields);
}
// Rewrite a parameter's variable and type; reuse the node if unchanged.
Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) {
  Var var = Downcast<Var>(this->Mutate(op->var));
  auto type = this->VisitType(op->type);
  // Consistency fix: use same_as (referential equality) for change
  // detection, matching every other visitor in this file instead of
  // raw operator== on the handles.
  if (var.same_as(op->var) && type.same_as(op->type)) {
    return e;
  } else {
    return ParamNode::make(var, type);
  }
}
// Rewrite type params, params, return type and body, tracking whether
// anything actually changed so the original node can be reused.
Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
  tvm::Array<TypeParam> ty_params;
  bool all_ty_params_unchanged = true;
  for (const auto& ty_param : op->type_params) {
    TypeParam new_ty_param = Downcast<TypeParam>(VisitType(ty_param));
    ty_params.push_back(new_ty_param);
    all_ty_params_unchanged &= new_ty_param.same_as(ty_param);
  }

  tvm::Array<Param> params;
  bool all_params_unchanged = true;
  for (const auto& param : op->params) {
    Param new_param = Downcast<Param>(this->Mutate(param));
    params.push_back(new_param);
    all_params_unchanged &= param.same_as(new_param);
  }

  auto ret_type = this->VisitType(op->ret_type);
  auto body = this->Mutate(op->body);

  // Bug fix: the per-element change flags were computed but never used —
  // the old condition called same_as on the freshly built arrays, which
  // is always false for a new array, so an unchanged function was still
  // reallocated every visit. (The flags were also misnamed "*_changed".)
  if (all_ty_params_unchanged && all_params_unchanged &&
      ret_type.same_as(op->ret_type) && body.same_as(op->body)) {
    return e;
  } else {
    return FunctionNode::make(params, ret_type, body, ty_params);
  }
}
// Rewrite the callee, the type arguments, and the positional arguments,
// reusing the original call node when all of them are unchanged.
Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) {
  Expr new_op = this->Mutate(call_node->op);
  bool unchanged = call_node->op.same_as(new_op);

  tvm::Array<Type> ty_args;
  for (const auto& ty_arg : call_node->type_args) {
    Type new_ty_arg = this->VisitType(ty_arg);
    ty_args.push_back(new_ty_arg);
    unchanged &= new_ty_arg.same_as(ty_arg);
  }

  tvm::Array<Expr> call_args;
  for (const auto& arg : call_node->args) {
    Expr new_arg = this->Mutate(arg);
    call_args.push_back(new_arg);
    unchanged &= new_arg.same_as(arg);
  }

  if (unchanged) {
    return e;
  }
  return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
}
// Rewrite the bound variable, its annotated type, the bound value, and
// the body; reuse the original node when nothing changed.
Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) {
  Var var = Downcast<Var>(this->Mutate(op->var));
  auto type = this->VisitType(op->value_type);
  auto value = this->Mutate(op->value);
  auto body = this->Mutate(op->body);
  bool unchanged = var.same_as(op->var) && type.same_as(op->value_type) &&
                   value.same_as(op->value) && body.same_as(op->body);
  if (unchanged) {
    return e;
  }
  return LetNode::make(var, value, body, type);
}
// Rewrite the condition and both branches; reuse the node if unchanged.
Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) {
  auto guard = this->Mutate(op->cond);
  auto true_b = this->Mutate(op->true_branch);
  auto false_b = this->Mutate(op->false_branch);
  // Consistency fix: use same_as for change detection, matching the
  // other visitors, instead of raw operator== on the handles.
  if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
      op->false_branch.same_as(false_b)) {
    return e;
  } else {
    return IfNode::make(guard, true_b, false_b);
  }
}
// Default type hook: leave types untouched; subclasses override to rewrite.
Type ExprMutator::VisitType(const Type& t) { return t; }

// Leaf expressions carry no children, so there is nothing to walk.
// Cleanup: dropped the redundant "ExprVisitor::ExprVisitor::" qualification,
// which relied on the injected-class-name and adds nothing.
void ExprVisitor::VisitExpr_(const VarNode* op) { return; }

void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; }

void ExprVisitor::VisitExpr_(const ConstantNode* op) { return; }

void ExprVisitor::VisitExpr_(const TupleNode* op) {
  for (const auto& field : op->fields) {
    this->VisitExpr(field);
  }
}

void ExprVisitor::VisitExpr_(const ParamNode* op) {
  // A parameter's only expression child is its variable.
  this->VisitExpr(op->var);
}

void ExprVisitor::VisitExpr_(const FunctionNode* op) {
  for (const auto& param : op->params) {
    this->VisitExpr(param);
  }
  this->VisitExpr(op->body);
}
// Visit the callee, then type arguments, then positional arguments.
void ExprVisitor::VisitExpr_(const CallNode* op) {
  this->VisitExpr(op->op);
  for (const auto& ty_arg : op->type_args) {
    this->VisitType(ty_arg);
  }
  for (const auto& arg : op->args) {
    this->VisitExpr(arg);
  }
}

// Visit the bound variable, its value, then the let body.
void ExprVisitor::VisitExpr_(const LetNode* op) {
  this->VisitExpr(op->var);
  this->VisitExpr(op->value);
  this->VisitExpr(op->body);
}

// Visit the condition followed by both branches.
void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}

// Operators are leaves; the default type hook is a no-op as well.
void ExprVisitor::VisitExpr_(const OpNode* op) { return; }

void ExprVisitor::VisitType(const Type& t) { return; }
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/op.cc
* \brief Resolve incomplete types to complete types.
*/
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <memory>
#include <mutex>
#include "./../pass/type_subst.h"
namespace dmlc {
// Enable the dmlc registry machinery for OpRegistry entries
// (backs the RELAY_REGISTER_OP-style registration).
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
}  // namespace dmlc

namespace tvm {
namespace relay {

// Accessor for the process-wide operator registry.
::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
  return ::dmlc::Registry<OpRegistry>::Get();
}
// single manager of operator information.
// single manager of operator information.
struct OpManager {
  // mutex to avoid registration from multiple threads.
  std::mutex mutex;
  // global operator counter; assigns each Op a dense, unique index.
  std::atomic<int> op_counter{0};
  // storage of additional attribute table, keyed by attribute name.
  std::unordered_map<std::string, std::unique_ptr<GenericOpMap>> attr;
  // frontend functions
  std::vector<PackedFunc*> frontend_funcs;
  // get singleton of the op manager
  static OpManager* Global() {
    static OpManager inst;
    return &inst;
  }
};
// find operator by name
const Op& Op::Get(const std::string& name) {
const OpRegistry* reg = dmlc::Registry<OpRegistry>::Find(name);
CHECK(reg != nullptr) << "Operator " << name << " is not registered";
return reg->op();
}
// Construct a fresh registry entry: allocate the underlying OpNode and
// hand it the next global operator index.
OpRegistry::OpRegistry() {
  auto node = std::make_shared<OpNode>();
  node->index_ = OpManager::Global()->op_counter++;
  op_ = Op(node);
}
// Fetch the generic attribute table registered under `key`.
// Fatals if no attribute of that name was ever registered.
const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
  OpManager* manager = OpManager::Global();
  std::lock_guard<std::mutex> guard(manager->mutex);
  auto entry = manager->attr.find(key);
  if (entry == manager->attr.end()) {
    LOG(FATAL) << "Operator attribute \'" << key << "\' is not registered";
  }
  return *entry->second.get();
}
void OpRegistry::UpdateAttr(const std::string& key, TVMRetValue value,
int plevel) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
if (op_map == nullptr) {
op_map.reset(new GenericOpMap());
}
uint32_t index = op_->index_;
if (op_map->data_.size() <= index) {
op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
}
std::pair<TVMRetValue, int>& p = op_map->data_[index];
CHECK(p.second != plevel)
<< "Attribute " << key << " of operator " << this->name
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
op_map->data_[index] = std::make_pair(value, plevel);
}
}
// Frontend APIs
// List the names of every registered operator as an array of string exprs.
TVM_REGISTER_API("relay.op._ListOpNames")
    .set_body_typed<Array<tvm::Expr>()>([]() {
      Array<tvm::Expr> ret;
      for (const std::string& name :
           dmlc::Registry<OpRegistry>::ListAllNames()) {
        ret.push_back(tvm::Expr(name));
      }
      return ret;
    });
// Look up a single operator by name (fatals if missing, see Op::Get).
TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
// Fetch attribute `args[1]` of operator `args[0]`; leaves the return
// value unset when the operator has no entry in that attribute table.
TVM_REGISTER_API("relay.op._OpGetAttr")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      Op op = args[0];
      std::string attr_name = args[1];
      auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
      if (op_map.count(op)) {
        *rv = op_map[op];
      }
    });
// Registration entry point used by the frontend.
//   args[0] = op name, args[1] = attribute key,
//   args[2] = attribute value, args[3] = priority level.
TVM_REGISTER_API("relay.op._Register")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      std::string op_name = args[0];
      std::string attr_key = args[1];
      runtime::TVMArgValue value = args[2];
      int plevel = args[3];
      auto& reg =
          OpRegistry::Registry()->__REGISTER_OR_GET__(op_name).set_name();
      // enable registration and override of certain properties
      if (attr_key == "num_inputs" && plevel > 128) {
        reg.set_num_inputs(value);
      } else if (attr_key == "attrs_type_key" && plevel > 128) {
        reg.set_attrs_type_key(value);
      } else {
        // normal attr table override.
        if (args[2].type_code() == kFuncHandle) {
          // do an eager copy of the PackedFunc
          PackedFunc f = args[2];
          // If we get a function from frontend, avoid deleting it.
          OpManager::Global()->frontend_funcs.push_back(new PackedFunc(f));
          reg.set_attr(attr_key, f, plevel);
        } else {
          reg.set_attr(attr_key, args[2], plevel);
        }
      }
    });
std::shared_ptr<OpNode> CreateOp(const std::string& name) {
auto op = Op::Get(name);
CHECK(!op.defined()) << "Cannot find op \'" << name << '\'';
std::shared_ptr<Node> node = op.node_;
return std::dynamic_pointer_cast<OpNode>(node);
}
// Hook OpNode into TVM node serialization: ops are serialized by their
// global name and re-created via registry lookup (CreateOp) on load.
TVM_REGISTER_NODE_TYPE(OpNode)
    .set_creator(CreateOp)
    .set_global_key([](const Node* n) {
      return static_cast<const OpNode*>(n)->name;
    });
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/ir/type.cc
* \brief The type system AST nodes of Relay.
*/
#include <tvm/ir_functor.h>
#include <tvm/relay/type.h>
namespace tvm {
namespace relay {
using tvm::IRPrinter;
using namespace tvm::runtime;
// Build a TensorType node from a symbolic shape and an element dtype.
TensorType TensorTypeNode::make(Array<ShapeExpr> shape, DataType dtype) {
  auto node = std::make_shared<TensorTypeNode>();
  node->shape = std::move(shape);
  node->dtype = std::move(dtype);
  return TensorType(node);
}
// Convenience constructor for a rank-0 (scalar) tensor type.
TensorType TensorTypeNode::Scalar(DataType dtype) {
  return TensorTypeNode::make({}, dtype);
}
// Frontend constructor for TensorType: args are (shape, dtype).
TVM_REGISTER_API("relay._make.TensorType")
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      Array<ShapeExpr> shape = args[0];
      *ret = TensorTypeNode::make(shape, args[1]);
    });
// Pretty printer for TensorTypeNode.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
    .set_dispatch<TensorTypeNode>([](const TensorTypeNode *node,
                                     tvm::IRPrinter *p) {
      p->stream << "TensorTypeNode(" << node->dtype << ", " << node->shape << ")";
    });
// Build a type parameter node: a fresh tvm::Var named `name` together with
// the kind of the parameter.
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
  auto node = std::make_shared<TypeParamNode>();
  node->var = tvm::Var(name);
  node->kind = kind;  // Kind is an enum; a plain copy suffices.
  return TypeParam(node);
}
// Frontend constructor for TypeParam: args are (name, kind-as-int).
TVM_REGISTER_API("relay._make.TypeParam")
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      int kind = args[1];
      *ret =
          TypeParamNode::make(args[0], static_cast<TypeParamNode::Kind>(kind));
    });
// Pretty printer for TypeParamNode.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
    .set_dispatch<TypeParamNode>([](const TypeParamNode *node,
                                    tvm::IRPrinter *p) {
      p->stream << "TypeParamNode(" << node->var->name_hint << ", "
                << node->kind << ")";
    });
// Assemble a function type from argument types, a return type, and the
// type parameters / constraints the function is quantified over.
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, Type ret_type,
                            tvm::Array<TypeParam> type_params,
                            tvm::Array<TypeConstraint> type_constraints) {
  auto node = std::make_shared<FuncTypeNode>();
  node->arg_types = std::move(arg_types);
  node->ret_type = std::move(ret_type);
  node->type_params = std::move(type_params);
  node->type_constraints = std::move(type_constraints);
  return FuncType(node);
}
// Frontend constructor for FuncType:
// args are (arg_types, ret_type, type_params, type_constraints).
TVM_REGISTER_API("relay._make.FuncType")
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
    });
// Pretty printer for FuncTypeNode.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
    .set_dispatch<FuncTypeNode>([](const FuncTypeNode *node,
                                   tvm::IRPrinter *p) {
      p->stream << "FuncTypeNode(" << node->type_params << ", "
                << node->arg_types << ", " << node->ret_type << ", "
                << node->type_constraints << ")";
    });
// Build a named type relation over `args`, solved by the function `func`.
TypeRelation TypeRelationNode::make(std::string name, TypeRelationFn func, Array<Type> args) {
  auto node = std::make_shared<TypeRelationNode>();
  node->name = std::move(name);
  node->func_ = std::move(func);
  node->args = std::move(args);
  return TypeRelation(node);
}
// Frontend constructor for TypeRelation: args are (name, func, type args).
TVM_REGISTER_API("relay._make.TypeRelation")
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      *ret = TypeRelationNode::make(args[0], args[1], args[2]);
    });
// Pretty printer for TypeRelationNode (the relation function is omitted).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
    .set_dispatch<TypeRelationNode>([](const TypeRelationNode *node,
                                       tvm::IRPrinter *p) {
      p->stream << "TypeRelationNode(" << node->name << ", " << node->args
                << ")";
    });
// Build a tuple type from its field types.
TupleType TupleTypeNode::make(Array<Type> fields) {
  auto node = std::make_shared<TupleTypeNode>();
  node->fields = std::move(fields);
  return TupleType(node);
}
// Frontend constructor for TupleType: args are (fields,).
TVM_REGISTER_API("relay._make.TupleType")
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      *ret = TupleTypeNode::make(args[0]);
    });
// Pretty printer for TupleTypeNode.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
    .set_dispatch<TupleTypeNode>([](const TupleTypeNode *node,
                                    tvm::IRPrinter *p) {
      p->stream << "TupleTypeNode(" << node->fields << ")";
    });
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file elemwise.cc
* \brief Elementwise operators.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include "../type_relations.h"
namespace tvm {
namespace relay {
// Quick helper macro
// - Expose a positional make function to construct the node.
// - Register op to the registry.
//
// We make the decision to always only expose positional argument.
// We will do rewrapping in the frontend to support language
// sugars such as keyword arguments and default value.
//
// Each use registers both the "relay.op._make.<name>" frontend hook and
// the operator itself (one "data" tensor input); the caller chains the
// remaining registry properties (describe/support level/type relation).
#define RELAY_REGISTER_UNARY_OP(OpName)                     \
  TVM_REGISTER_API("relay.op._make." OpName)                \
      .set_body_typed<Expr(Expr)>([](Expr data) {           \
        static const Op& op = Op::Get(OpName);              \
        return CallNode::make(op, {data}, Attrs(), {});     \
      });                                                   \
  RELAY_REGISTER_OP(OpName)                                 \
      .set_num_inputs(1)                                    \
      .add_argument("data", "Tensor", "The input tensor.")
// relay.log: elementwise natural logarithm.
RELAY_REGISTER_UNARY_OP("log")
.describe(R"code(Returns the log input array, computed element-wise.
.. math::
   log(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
// Identity relation:
// data : Tensor[shape, dtype]
// result: Tensor[shape, dtype]
// relay.exp: elementwise exponential.
RELAY_REGISTER_UNARY_OP("exp")
.describe(R"code(Returns the exp input array, computed element-wise.
.. math::
   \exp(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
// relay.sqrt: elementwise square root.
RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the sqrt input array, computed element-wise.
.. math::
   sqrt(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);
// Addition
// Frontend constructor: relay.op._make.add(lhs, rhs).
TVM_REGISTER_API("relay.op._make.add")
    .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
      static const Op& op = Op::Get("add");
      return CallNode::make(op, {lhs, rhs}, Attrs(), {});
    });
RELAY_REGISTER_OP("add")
  .set_num_inputs(2)
  .add_argument("lhs", "Tensor", "The left hand side tensor.")
  .add_argument("rhs", "Tensor", "The right hand side tensor.")
  .set_support_level(1)
  .add_type_rel("Broadcast", BroadcastRel);
// Broadcast relation:
// def broadcast(s1, s2):
//   ...
//
// input1: Tensor[dtype, s1]
// input2: Tensor[dtype, s2]
// output: Tensor[dtype, broadcast(s1, s2)]
// Subtraction
// Frontend constructor: relay.op._make.subtract(lhs, rhs).
TVM_REGISTER_API("relay.op._make.subtract")
    .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
      static const Op& op = Op::Get("subtract");
      return CallNode::make(op, {lhs, rhs}, Attrs(), {});
    });
RELAY_REGISTER_OP("subtract")
  .set_num_inputs(2)
  .add_argument("lhs", "Tensor", "The left hand side tensor.")
  .add_argument("rhs", "Tensor", "The right hand side tensor.")
  .set_support_level(1)
  .add_type_rel("Broadcast", BroadcastRel);
// Broadcast relation:
// def broadcast(s1, s2):
//   ...
//
// input1: Tensor[dtype, s1]
// input2: Tensor[dtype, s2]
// output: Tensor[dtype, broadcast(s1, s2)]
// Equality Comparison
// Frontend constructor: relay.op._make.equal(lhs, rhs).
// Uses the boolean-result broadcast relation (output dtype is bool).
TVM_REGISTER_API("relay.op._make.equal")
    .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
      static const Op& op = Op::Get("equal");
      return CallNode::make(op, {lhs, rhs}, Attrs(), {});
    });
RELAY_REGISTER_OP("equal")
  .set_num_inputs(2)
  .add_argument("lhs", "Tensor", "The left hand side tensor.")
  .add_argument("rhs", "Tensor", "The right hand side tensor.")
  .set_support_level(1)
  .add_type_rel("BroadcastComp", BroadcastCompRel);
// Concat
// Frontend constructor: relay.op._make.concat(tuple).
// Takes a single tuple of tensors; concatenation is along axis 0
// (see ConcreteConcatRel in type_relations.cc).
TVM_REGISTER_API("relay.op._make.concat")
    .set_body_typed<Expr(Expr)>([](Expr tuple) {
      static const Op& op = Op::Get("concat");
      return CallNode::make(op, { tuple }, Attrs(), {});
    });
RELAY_REGISTER_OP("concat")
  .set_num_inputs(1)
  .add_argument("tuple", "Tuple", "The tupled tensor arguments.")
  .set_support_level(1)
  .add_type_rel("Concat", ConcatRel);
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file type_relations.cc
* \brief A set of utilities and common functionality
* for type relations.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/logging.h>
#include <tvm/relay/op.h>
#include <numeric>
#include "../pass/incomplete_type.h"
#include "./type_relations.h"
namespace tvm {
namespace relay {
// Downcast a Type to TensorType; yields a null TensorType when the
// underlying node is not a tensor type.
TensorType ToTensorType(const Type& t) {
  const auto* node = t.as<TensorTypeNode>();
  return node ? GetRef<TensorType>(node) : TensorType(nullptr);
}
// TODO(@jroesch) what size value do we extract, 64bit or 32bit?
// Extract a concrete integer from a shape expression. Fatals if the
// expression is undefined or not a constant IntImm. Note the int64 value
// is narrowed to int (see TODO above).
int ToInt(const tvm::Expr& e) {
  CHECK(e.defined());
  auto imm = e.as<tvm::ir::IntImm>();
  // Bug fix: the failure message previously printed `imm->type`, which
  // dereferences a null pointer exactly when the check fires. Print the
  // offending expression instead.
  CHECK(imm) << "expected a constant integer dimension, got: " << e
             << std::endl;
  return imm->value;
}
// Identity type relation: when the input is a known tensor type and the
// output slot is still unsolved, propagate the input type to the output.
// Otherwise leave the types untouched.
Array<Type> IdentityRel(const Array<Type>& types, int num_args) {
  CHECK_EQ(types.size(), 2);
  auto in_ty = ToTensorType(types[0]);
  if (in_ty && types[1].as<IncompleteTypeNode>()) {
    return {in_ty, in_ty};
  }
  return types;
}
// Compute the broadcasted result type of two concrete tensor types,
// following trailing-dimension broadcasting (aligned dims must be equal
// or one of them must be 1), with element type `output_dtype`.
static Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2,
                              DataType output_dtype) {
  RELAY_LOG(INFO) << "ConcreteBroadcast: t1=" << t1 << " t2=" << t2
                  << std::endl;
  auto sh1 = t1->shape;
  auto sh2 = t2->shape;
  RELAY_LOG(INFO) << "ConcreteBroadcast: sh1=" << sh1 << " sh2=" << sh2
                  << std::endl;
  if (sh1.size() == 0 && sh2.size() == 0) {
    // Two scalars broadcast to a scalar.
    return TensorTypeNode::make({}, output_dtype);
    // We have non-zero shapes so broadcast rules apply.
  } else {
    auto suffix_len = static_cast<int>(std::min(sh1.size(), sh2.size()));
    auto full_len = static_cast<int>(std::max(sh1.size(), sh2.size()));
    // Validate the aligned (trailing) dimensions pairwise.
    auto rev_sh1 = sh1.rbegin();
    auto rev_sh2 = sh2.rbegin();
    while (rev_sh1 != sh1.rend() && rev_sh2 != sh2.rend()) {
      auto dim1 = ToInt(*rev_sh1);
      auto dim2 = ToInt(*rev_sh2);
      if ((dim1 != dim2) && ((dim1 != 1) && (dim2 != 1))) {
        CHECK(false) << "Dimension mistmatch "
                     << "dim1: " << dim1 << " dim2: " << dim2 << std::endl;
      }
      rev_sh1++;
      rev_sh2++;
    }
    // Left-pad the shorter shape with 1s so both have full_len dims.
    Array<ShapeExpr> larger;
    Array<ShapeExpr> smaller;
    for (int i = 0; i < (full_len - suffix_len); i++) {
      smaller.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), 1));
    }
    if (sh1.size() < sh2.size()) {
      for (auto sh : sh1) {
        smaller.push_back(sh);
      }
      larger = sh2;
    } else if (sh1.size() > sh2.size()) {
      // Bug fix: this branch previously assigned `smaller = sh2;`,
      // discarding the padding of 1s pushed above and leaving `smaller`
      // shorter than `larger`, which tripped the CHECK_EQ below whenever
      // sh1 had higher rank. Append sh2 after the padding instead,
      // mirroring the branch above.
      for (auto sh : sh2) {
        smaller.push_back(sh);
      }
      larger = sh1;
    } else {
      larger = sh1;
      smaller = sh2;
    }
    CHECK_EQ(larger.size(), smaller.size());
    // Per aligned pair the result dim is the max of the two (they are
    // equal or one of them is 1, per the validation loop above).
    Array<HalideIR::Expr> out_shape;
    for (size_t i = 0; i < smaller.size(); i++) {
      auto left = smaller[i].as<tvm::ir::IntImm>();
      auto right = larger[i].as<tvm::ir::IntImm>();
      CHECK(left);
      CHECK(right);
      int64_t dim = std::max(left->value, right->value);
      out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim));
    }
    return TensorTypeNode::make(out_shape, output_dtype);
  }
}
// Broadcast relation for binary elementwise ops: once both inputs are
// concrete tensor types (with matching dtype), solve the output slot via
// ConcreteBroadcast; otherwise report no progress by returning the types
// unchanged.
Array<Type> BroadcastRel(const Array<Type>& types, int num_args) {
  CHECK_EQ(types.size(), 3);
  RELAY_LOG(INFO) << "In1: " << types[0] << "In2: " << types[1]
                  << "Out: " << types[2] << std::endl;
  auto lhs = ToTensorType(types[0]);
  auto rhs = lhs ? ToTensorType(types[1]) : TensorType(nullptr);
  if (lhs && rhs) {
    CHECK_EQ(lhs->dtype, rhs->dtype);
    return {lhs, rhs, ConcreteBroadcast(lhs, rhs, lhs->dtype)};
  }
  return types;
}
/* A relation which specifies broadcasting rules for operations which
   compute boolean results (equal, not_equal, lt, ...): the output has the
   broadcasted shape but a bool element type.
 */
Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args) {
  CHECK_EQ(types.size(), 3);
  auto lhs = ToTensorType(types[0]);
  auto rhs = lhs ? ToTensorType(types[1]) : TensorType(nullptr);
  if (lhs && rhs) {
    return {lhs, rhs, ConcreteBroadcast(lhs, rhs, HalideIR::Bool())};
  }
  return types;
}
/*! \brief Handle concrete concat case from known input to output.
 *
 * Given a tuple of tensor types, compute the type of concatenating them
 * along axis 0: the axis-0 dims are summed, the remaining dims must match
 * across all fields, and the dtype is taken from the first field.
 * Throws TypeRelationError if the input is not a tuple.
 */
inline Type ConcreteConcatRel(const Type& input_type) {
  if (auto tuple_node = input_type.as<TupleTypeNode>()) {
    // NB: For now the axis argument is hardwired to be 0.
    std::vector<int> dims;
    DataType dtype;
    // NOTE(review): this requires at least two fields; a single-element
    // tuple fatals here — confirm that is intended.
    CHECK_LT(1, tuple_node->fields.size());
    bool skip_first = true;
    // Collect the suffix dimensions since axis is zero.
    // TODO(@jroesch): This is a demonstration of how
    // to do varargs. It requires a little more work to
    // fully type the behavior of concat.
    auto first = Downcast<TensorType>(tuple_node->fields[0]);
    dtype = first->dtype;
    for (auto dim_expr : first->shape) {
      if (!skip_first) {
        dims.push_back(ToInt(dim_expr));
      } else {
        skip_first = false;
      }
    }
    // Sum the axis-0 dims and check every field's suffix dims against
    // the first field's.
    std::vector<int> axis_dims;
    for (auto field_ty : tuple_node->fields) {
      auto ttype = Downcast<TensorType>(field_ty);
      for (size_t i = 0; i < ttype->shape.size(); i++) {
        if (i != 0) {
          CHECK_EQ(ToInt(dims[i - 1]), ToInt(ttype->shape[i]));
        } else {
          axis_dims.push_back(ToInt(ttype->shape[i]));
        }
      }
    }
    auto out_axis_dim = std::accumulate(axis_dims.begin(), axis_dims.end(), 0);
    Array<tvm::Expr> out_shape = { tvm::ir::IntImm::make(HalideIR::Int(64), out_axis_dim) };
    for (auto dim : dims) {
      out_shape.push_back(tvm::ir::IntImm::make(HalideIR::Int(64), dim));
    }
    return TensorTypeNode::make(out_shape, dtype);
  } else {
    throw TypeRelationError("concat can only be used with a tuple as its argument");
  }
}
// Concat relation over (input, output). Solves the output from a known
// tuple input; leaves both untouched if neither side is known yet.
// NOTE(review): the case where the output is known but the input is not
// (or both are known) currently throws — there is no backward deduction.
Array<Type> ConcatRel(const Array<Type>& types, int num_args) {
  CHECK_EQ(types.size(), 2);
  if (types[0].as<IncompleteTypeNode>() && types[1].as<IncompleteTypeNode>()) {
    // No information yet; make no progress.
    return types;
  } else if (types[1].as<IncompleteTypeNode>()) {
    // Input known, output unsolved: compute it.
    return { types[0], ConcreteConcatRel(types[0]) };
  } else {
    throw TypeRelationError(
        "can not deduce relationship between the " \
        "type of concat's input and output");
  }
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/op/type_relations.h
* \brief A set of utilities and common functionality
* for type relations.
*/
#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
#define TVM_RELAY_OP_TYPE_RELATIONS_H_
#include <tvm/relay/error.h>
#include <tvm/relay/type.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief The error raised by a type relation.
 *
 * This error is how a type relation signals that it has failed
 * (e.g. the input types cannot satisfy the relation).
 */
struct TypeRelationError : Error {
  /*! \brief Construct the error from a human-readable message. */
  explicit TypeRelationError(const std::string& msg)
      : Error(msg) {}
};
/*! \brief The identity type relation maps a single input variable
 * to the output variable.
 *
 * \param types The input and output types to the relation.
 * \param num_args The number of input arguments.
 * \return The (potentially partial) solution to the relation.
 */
Array<Type> IdentityRel(const Array<Type>& types, int num_args);
/*! \brief The broadcast type relation, implements the broadcasting
 * rule over the two input types producing the broadcasted type.
 *
 * \param types The input and output types to the relation.
 * \param num_args The number of input arguments.
 * \return The (potentially partial) solution to the relation.
 */
Array<Type> BroadcastRel(const Array<Type>& types, int num_args);
/*! \brief The broadcast type relation, implements the broadcasting
 * rule over the two input types producing the broadcasted type.
 *
 * This differs from BroadcastRel in the return dtype,
 * it instead returns bool, for use in comparison operators
 * such as equal, not_equal, lt, and so on.
 *
 * \param types The input and output types to the relation.
 * \param num_args The number of input arguments.
 * \return The (potentially partial) solution to the relation.
 */
Array<Type> BroadcastCompRel(const Array<Type>& types, int num_args);
/*! \brief The concat relation.
 *
 * This relation takes a single input which must be a single tensor
 * or an arbitrary sized tuple. It combines these input dimensions
 * together to produce the output type.
 *
 * \param types The input and output types to the relation.
 * \param num_args The number of input arguments.
 * \return The (potentially partial) solution to the relation.
 */
Array<Type> ConcatRel(const Array<Type>& types, int num_args);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_TYPE_RELATIONS_H_
/*!
 * Copyright (c) 2018 by Contributors
 * \file src/tvm/relay/pass/alpha_eq.cc
 * \brief Check alpha-equivalence of Relay expressions and types.
 */
#include <tvm/relay/expr_functor.h>
#include "./type_visitor.h"
#include "tvm/relay/pass.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
// Visitor that checks structural equality of two types up to alpha-renaming
// of type parameters. The result accumulates into `equal`; once any pair of
// sub-terms mismatches it stays false.
struct TypeAlphaEq : TypeVisitor<const Type&> {
  // Maps LHS type params to their expected RHS counterparts.
  // NOTE(review): nothing inside this visitor inserts into eq_map, so
  // distinct unmapped params always compare unequal — confirm binders are
  // supposed to populate this map.
  tvm::Map<TypeParam, TypeParam> eq_map;
  // Running equality result.
  bool equal;
  TypeAlphaEq() : eq_map(), equal(true) {}
  // Fold dtype equality into the running result.
  void DataTypeEqual(const DataType& dt1, const DataType& dt2) {
    equal = equal && dt1 == dt2;
  }
  // NOTE(review): shape comparison is currently a no-op, so tensor types
  // with different shapes compare equal — confirm this is intentional.
  void ShapeEqual(Array<ShapeExpr> s1, Array<ShapeExpr> s2) {}
  void VisitType_(const TensorTypeNode *tt1, const Type& t2) final {
    if (const TensorTypeNode *tt2 = t2.as<TensorTypeNode>()) {
      DataTypeEqual(tt1->dtype, tt2->dtype);
      ShapeEqual(tt1->shape, tt2->shape);
    } else {
      equal = false;
    }
  }
  // Incomplete types are equal only by node identity.
  void VisitType_(const IncompleteTypeNode *bt1, const Type& t2) final {
    if (const IncompleteTypeNode *bt2 = t2.as<IncompleteTypeNode>()) {
      equal = equal && bt1 == bt2;
      return;
    } else {
      equal = false;
    }
  }
  void VisitType_(const TypeParamNode *ti1, const Type& t2) final {
    if (const TypeParamNode *ti2 = t2.as<TypeParamNode>()) {
      auto tid1 = GetRef<TypeParam>(ti1);
      auto tid2 = GetRef<TypeParam>(ti2);
      // We handle open terms with this rule assuming variables are identical.
      //
      // Not sure if we should do this.
      if (tid1 == tid2) {
        return;
      }
      // Check that they are same kind
      if (tid1->kind != tid2->kind) {
        equal = false;
        return;
      }
      // Next we see if there is mapping for local1 into the rhs term.
      // If there is we check to see if those are equal.
      if (eq_map.find(tid1) != eq_map.end()) {
        equal = equal && eq_map[tid1] == tid2;
      } else {
        equal = false;
      }
    } else {
      equal = false;
    }
  }
  // NOTE(review): only arg types and return type are compared here;
  // type_params and type_constraints are ignored — confirm intent.
  void VisitType_(const FuncTypeNode *op, const Type& t2) final {
    if (const FuncTypeNode *ta2 = t2.as<FuncTypeNode>()) {
      if (op->arg_types.size() != ta2->arg_types.size()) {
        equal = false;
        return;
      }
      for (size_t i = 0; i < op->arg_types.size(); i++) {
        this->VisitType(op->arg_types[i], ta2->arg_types[i]);
        if (!equal) {
          return;
        }
      }
      this->VisitType(op->ret_type, ta2->ret_type);
    } else {
      equal = false;
    }
  }
  // Type relations are equal only by node identity.
  void VisitType_(const TypeRelationNode *tr1, const Type& t2) final {
    if (const TypeRelationNode *tr2 = t2.as<TypeRelationNode>()) {
      equal = tr1 == tr2;
    } else {
      equal = false;
    }
  }
  void VisitType_(const TupleTypeNode *op, const Type& t2) final {
    if (const TupleTypeNode *pt = t2.as<TupleTypeNode>()) {
      if (op->fields.size() != pt->fields.size()) {
        equal = false;
        return;
      }
      for (size_t i = 0U; i < op->fields.size(); i++) {
        if (!equal) {
          return;
        }
        this->VisitType(op->fields[i], pt->fields[i]);
      }
    } else {
      equal = false;
    }
  }
};
bool AlphaEqual(const Type& t1, const Type& t2) {
TypeAlphaEq aeq;
aeq.VisitType(t1, t2);
return aeq.equal;
}
// Visitor that checks alpha-equivalence of two expressions: structural
// equality where bound variables (let vars, params) may differ in name as
// long as they correspond positionally. The result accumulates in `equal`.
struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
 public:
  // Maps binders of the LHS to their counterparts on the RHS; populated
  // when visiting Param and Let nodes.
  tvm::Map<Var, Var> eq_map;
  // Running equality result.
  bool equal;
  AlphaEq() : eq_map(), equal(true) {}
  void VisitExpr_(const VarNode *e1, const Expr& e2) final {
    if (const VarNode *id2 = e2.as<VarNode>()) {
      auto local1 = GetRef<Var>(e1);
      auto local2 = GetRef<Var>(id2);
      // We handle open terms with this rule assuming variables are identical.
      if (local1 == local2) {
        equal = true;
        return;
      }
      // Next we see if there is mapping for local1 into the rhs term.
      // If there is we check to see if those are equal.
      if (eq_map.find(local1) != eq_map.end()) {
        equal = equal && eq_map[local1] == local2;
      } else {
        equal = false;
      }
    } else {
      equal = false;
    }
  }
  // Global vars are equal only by node identity.
  void VisitExpr_(const GlobalVarNode *g1, const Expr& e2) final {
    if (const GlobalVarNode *g2 = e2.as<GlobalVarNode>()) {
      equal = equal && g1 == g2;
    } else {
      equal = false;
    }
  }
  void VisitExpr_(const TupleNode *pl1, const Expr& e2) final {
    Tuple prod1 = GetRef<Tuple>(pl1);
    if (const TupleNode *pl2 = e2.as<TupleNode>()) {
      Tuple prod2 = GetRef<Tuple>(pl2);
      if (prod1->fields.size() != prod2->fields.size()) {
        equal = false;
        return;
      }
      for (size_t i = 0U; i < prod1->fields.size(); i++) {
        this->VisitExpr(prod1->fields[i], prod2->fields[i]);
      }
    } else {
      equal = false;
    }
  }
  // Params bind their variables: record the correspondence, then require
  // the annotated types to be alpha-equal.
  void VisitExpr_(const ParamNode *p1, const Expr& e2) final {
    if (const ParamNode *p2 = e2.as<ParamNode>()) {
      eq_map.Set(p1->var, p2->var);
      equal = equal && AlphaEqual(p1->type, p2->type);
    } else {
      equal = false;
    }
  }
  void VisitExpr_(const FunctionNode *func1, const Expr& e2) final {
    if (const FunctionNode *func2 = e2.as<FunctionNode>()) {
      if (func1->params.size() != func2->params.size()) {
        equal = false;
        return;
      }
      for (size_t i = 0U; i < func1->params.size(); i++) {
        this->VisitExpr(func1->params[i], func2->params[i]);
      }
      this->VisitExpr(func1->body, func2->body);
    } else {
      equal = false;
    }
  }
  // NOTE(review): type_args of the calls are not compared here — confirm
  // whether they should participate in equality.
  void VisitExpr_(const CallNode *op, const Expr& e2) final {
    if (const CallNode *call = e2.as<CallNode>()) {
      this->VisitExpr(op->op, call->op);
      if (op->args.size() != call->args.size()) {
        equal = false;
        return;
      }
      for (size_t i = 0U; i < op->args.size(); i++) {
        this->VisitExpr(op->args[i], call->args[i]);
      }
    } else {
      equal = false;
    }
  }
  // Let binds its variable: record the correspondence before comparing
  // the bound value and the body.
  void VisitExpr_(const LetNode *op, const Expr& e2) final {
    if (const LetNode *let = e2.as<LetNode>()) {
      eq_map.Set(op->var, let->var);
      this->VisitExpr(op->value, let->value);
      this->VisitExpr(op->body, let->body);
    } else {
      equal = false;
    }
  }
};
bool AlphaEqual(const Expr& e1, const Expr& e2) {
AlphaEq eq;
eq.VisitExpr(e1, e2);
return eq.equal;
}
// TODO(@jroesch): move to correct namespace?
// Frontend hook: alpha-equality on expressions.
TVM_REGISTER_API("relay._make._alpha_eq")
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      Expr e1 = args[0];
      Expr e2 = args[1];
      *ret = AlphaEqual(e1, e2);
    });
// Frontend hook: alpha-equality on types.
TVM_REGISTER_API("relay._make._type_alpha_eq")
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      Type t1 = args[0];
      Type t2 = args[1];
      *ret = AlphaEqual(t1, t2);
    });
} // namespace relay
} // namespace tvm
/*!
 * Copyright (c) 2018 by Contributors
 * \file incomplete_type.h
 * \brief Defines IncompleteType, a placeholder node for types that have
 *  not yet been inferred.
 */
#ifndef TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
#define TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*!
 * \brief Represents a portion of an incomplete type.
 */
class IncompleteType;
/*! \brief IncompleteType container node.
 *
 * A placeholder ("hole") used during type inference; the unifier later
 * substitutes it with a concrete type (see ResolveTypeType in resolve.cc).
 * Carries the kind the eventual type must have.
 */
class IncompleteTypeNode : public TypeNode {
 public:
  // Kind the resolved type is expected to have.
  TypeParamNode::Kind kind;
  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("kind", &kind); }
  TVM_DLL static IncompleteType make(TypeParamNode::Kind kind);
  static constexpr const char* _type_key = "relay.IncompleteType";
  TVM_DECLARE_NODE_TYPE_INFO(IncompleteTypeNode, TypeNode);
};
// Define the IncompleteType reference wrapper over IncompleteTypeNode.
RELAY_DEFINE_NODE_REF(IncompleteType, IncompleteTypeNode, Type);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_INCOMPLETE_TYPE_H_
/*!
* Copyright (c) 2018 by Contributors
*
* \file kindchecker.cc
*
* \brief Check that types are well formed by applying "kinding rules".
*
* This pass ensures we do not do things that violate the design of the
* type system when writing down types.
*
* For example tensors are not allowed to contain functions in Relay.
*
* We check this by ensuring the `dtype` field of a Tensor always
* contains a data type such as `int`, `float`, `uint`.
*/
#include <tvm/ir_functor.h>
#include <tvm/relay/pass.h>
#include "./type_visitor.h"
namespace tvm {
namespace relay {
using namespace tvm::runtime;
// Visitor that walks a type and records whether it is well-kinded.
// NOTE(review): no VisitType_ overrides ever set `valid` to false, so
// Check() currently always returns true — presumably a placeholder
// implementation; confirm before relying on it.
struct KindChecker : TypeVisitor<> {
  // Result flag: stays true unless a kinding violation is recorded.
  bool valid;
  KindChecker() : valid(true) {}
  // Visit `t` and report whether it is well formed.
  bool Check(const Type &t) {
    this->VisitType(t);
    return valid;
  }
};
// Entry point for kind checking a type.
// NOTE(review): `env` is currently unused by the checker.
bool KindCheck(const Environment& env, const Type &t) {
  KindChecker kc;
  return kc.Check(t);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file resolve.cc
* \brief Resolve incomplete types to complete types.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include "./resolve.h"
#include "./type_visitor.h"
namespace tvm {
namespace relay {
// Type mutator that rewrites a type using the unifier's solutions:
// incomplete types become their substituted representatives, and
// undefined (null) types become fresh incomplete types registered with
// the unifier.
struct ResolveTypeType : TypeMutator {
  // The unifier holding the current substitution; borrowed, not owned.
  const TypeUnifier &unifier;
  explicit ResolveTypeType(const TypeUnifier &unifier) : unifier(unifier) {}
  Type VisitType(const Type &t) override {
    if (!t.defined()) {
      // Undefined stands for "to be inferred": allocate a fresh hole and
      // let the unifier track it.
      auto inc_ty = IncompleteTypeNode::make(TypeParamNode::Kind::kType);
      unifier->Insert(inc_ty);
      return inc_ty;
    } else {
      return TypeMutator::VisitType(t);
    }
  }
  Type VisitType_(const IncompleteTypeNode *op) override {
    // Replace the hole with whatever the unifier solved it to.
    return unifier->Subst(GetRef<IncompleteType>(op));
  }
};
// Expression mutator that resolves the stored checked type of every
// sub-expression in place while rebuilding the expression tree.
struct ResolveTypeExpr : ExprMutator {
  // The unifier holding the current substitution; borrowed, not owned.
  const TypeUnifier &unifier;
  explicit ResolveTypeExpr(const TypeUnifier &unifier) : unifier(unifier) {}
  Expr Mutate(const Expr &e) {
    // NB: a bit tricky here.
    //
    // We want to store resolved type without having
    // to re-typecheck the entire term.
    //
    // Since we know that e : T[...] under some holes
    // then it is the case that if we resolve types
    // present in e, then we can type it under T
    // with the wholes filled in.
    //
    // We will visit e like normal building a new
    // term, then resolve e's old type and write
    // it back into the new node.
    auto new_e = ExprMutator::Mutate(e);
    CHECK(e->checked_type_.defined());
    auto resolved_cty = VisitType(e->checked_type_);
    new_e->checked_type_ = resolved_cty;
    return new_e;
  }
  // Delegate type resolution to ResolveTypeType over the same unifier.
  Type VisitType(const Type &t) {
    return ResolveTypeType(unifier).VisitType(t);
  }
};
// Replace every incomplete type inside `ty` with its unifier
// representative; `ty` must be defined.
Type Resolve(const TypeUnifier &unifier, const Type &ty) {
  CHECK(ty.defined());
  ResolveTypeType resolver(unifier);
  return resolver.VisitType(ty);
}
// Replace every incomplete type inside `expr` (including the stored
// checked types) with its unifier representative.
Expr Resolve(const TypeUnifier &unifier, const Expr &expr) {
  ResolveTypeExpr resolver(unifier);
  return resolver.Mutate(expr);
}
struct FullyResolved : TypeVisitor<> {
bool incomplete;
FullyResolved() : incomplete(true) {}
void VisitType(const Type &t) override {
if (!t.defined()) {
incomplete = true;
} else {
return TypeVisitor<>::VisitType(t);
}
}
void VisitType_(const IncompleteTypeNode *ty_var) override {
incomplete = false;
}
};
bool IsFullyResolved(const Type &t) {
auto fr = FullyResolved();
fr.VisitType(t);
return fr.incomplete;
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/resolve.h
* \brief Resolve incomplete types to complete types.
*/
#ifndef TVM_RELAY_PASS_RESOLVE_H_
#define TVM_RELAY_PASS_RESOLVE_H_
#include <tvm/relay/expr.h>
#include <string>
#include "./unifier.h"
namespace tvm {
namespace relay {
/*! \brief Resolve a type containing incomplete types.
 *
 * This pass replaces incomplete types with their representative, and
 * converts types which are not defined into fresh variables.
 *
 * \param unifier The unifier containing the unification data.
 * \param ty The type to resolve.
 * \returns The resolved type.
 */
Type Resolve(const TypeUnifier& unifier, const Type& ty);
/*! \brief Resolve an expression containing incomplete types.
 *
 * This pass replaces incomplete types with their representative, and
 * converts types which are not defined into fresh variables.
 *
 * \param unifier The unifier containing the unification data.
 * \param expr The expression to resolve.
 * \returns The resolved expression.
 */
Expr Resolve(const TypeUnifier& unifier, const Expr& expr);
/*! \brief Check if all types have been filled in.
 * \param t The type.
 * \returns True if the type is resolved, false otherwise.
 */
bool IsFullyResolved(const Type& t);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_RESOLVE_H_
/*!
 * Copyright (c) 2018 by Contributors
 * \file type_functor.h
 * \brief A way to define an arbitrary function signature with dispatch on types.
 */
#ifndef TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#define TVM_RELAY_PASS_TYPE_FUNCTOR_H_
#include <tvm/ir_functor.h>
#include <tvm/relay/expr.h>
#include "./incomplete_type.h"
namespace tvm {
namespace relay {
// Primary template; only the function-signature specialization below is
// ever instantiated.
template <typename FType>
class TypeFunctor;
// functions to be overridden: default handler falls through to
// VisitTypeDefault_.
#define TYPE_FUNCTOR_DEFAULT \
  { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
// Registers a vtable entry that downcasts the node and forwards to the
// matching VisitType_ overload.
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                            \
  vtable.template set_dispatch<OP>(                                \
      [](const NodeRef& n, TSelf* self, Args... args) {            \
        return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
                                std::forward<Args>(args)...);      \
      });
// Dynamic-dispatch functor over relay Type nodes: calls the VisitType_
// overload matching the runtime node type, passing extra Args through.
template <typename R, typename... Args>
class TypeFunctor<R(const Type& n, Args...)> {
 private:
  using TSelf = TypeFunctor<R(const Type& n, Args...)>;
  using FType = tvm::IRFunctor<R(const NodeRef& n, TSelf* self, Args...)>;
 public:
  /*! \brief the result type of this functor */
  using result_type = R;
  /*! \brief virtual destructor */
  virtual ~TypeFunctor() {}
  /*!
   * \brief Same as call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  R operator()(const Type& n, Args... args) {
    return VisitType(n, std::forward<Args>(args)...);
  }
  /*!
   * \brief The functor call.
   * \param n The expression node.
   * \param args Additional arguments.
   * \return The result of the call
   */
  virtual R VisitType(const Type& n, Args... args) {
    static FType vtable = InitVTable();
    return vtable(n, this, std::forward<Args>(args)...);
  }
  // Functions that can be overridden by subclass; each defaults to
  // VisitTypeDefault_ (which fatals).
  virtual R VisitType_(const TensorTypeNode* op,
                       Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const TypeParamNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const TypeConstraintNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const FuncTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
  // Fallback for node types without a registered handler.
  virtual R VisitTypeDefault_(const Node* op, Args...) {
    LOG(FATAL) << "Do not have a default for " << op->type_key();
    return R();
  }
 private:
  // initialize the vtable.
  static FType InitVTable() {
    FType vtable;
    // Set dispatch
    RELAY_TYPE_FUNCTOR_DISPATCH(TensorTypeNode);
    RELAY_TYPE_FUNCTOR_DISPATCH(TypeParamNode);
    RELAY_TYPE_FUNCTOR_DISPATCH(TypeConstraintNode);
    RELAY_TYPE_FUNCTOR_DISPATCH(FuncTypeNode);
    RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
    RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
    RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
    return vtable;
  }
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_FUNCTOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file type_subst.cc
* \brief Function for substituting a concrete type in place of a type ID
*/
#include "./type_subst.h"
#include "./type_visitor.h"
namespace tvm {
namespace relay {
/*! \brief Mutator that replaces type parameters according to a
 * substitution map; all other type constructors are rebuilt as-is
 * by the inherited TypeMutator cases.
 */
struct TypeSubstV : TypeMutator {
  /*! \brief Map from type parameter to its replacement type. */
  tvm::Map<TypeParam, Type> subst_map;

  explicit TypeSubstV(tvm::Map<TypeParam, Type> subst_map)
      : subst_map(subst_map) {}

  Type VisitType_(const TypeParamNode* op) override {
    auto id = GetRef<TypeParam>(op);
    // Single lookup: reuse the find() iterator instead of following it
    // with operator[], which would search the map a second time.
    auto it = subst_map.find(id);
    if (it != subst_map.end()) {
      return (*it).second;
    } else {
      // Unbound type parameters are left untouched.
      return id;
    }
  }
};
/*! \brief Substitute a single type parameter inside `type`. */
Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst) {
  // Wrap the single binding in a map and reuse the map-based overload.
  return TypeSubst(type, tvm::Map<TypeParam, Type>({{target, subst}}));
}
/*! \brief Apply every binding in `subst_map` to `type`. */
Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map) {
  TypeSubstV substitutor(subst_map);
  return substitutor.VisitType(type);
}
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pass/type_subst.h
* \brief Utility functions for substituting types.
*/
#ifndef TVM_RELAY_PASS_TYPE_SUBST_H_
#define TVM_RELAY_PASS_TYPE_SUBST_H_
#include <tvm/relay/expr.h>
namespace tvm {
namespace relay {
/*!
 * \brief Substitute the type parameter `target` with `subst` inside `type`.
 * \param type The type to substitute into.
 * \param target The type parameter to replace.
 * \param subst The replacement type.
 * \return The substituted type.
 */
Type TypeSubst(const Type& type, const TypeParam& target, const Type& subst);
/*!
 * \brief Apply every substitution in `subst_map` to `type`.
 * \param type The type to substitute into.
 * \param subst_map Map from type parameters to their replacement types.
 * \return The substituted type.
 */
Type TypeSubst(const Type& type, tvm::Map<TypeParam, Type> subst_map);
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_SUBST_H_
/*!
* Copyright (c) 2018 by Contributors
* \file type_visitor.h
* \brief A wrapper around TypeFunctor for common use cases.
*/
#ifndef TVM_RELAY_PASS_TYPE_VISITOR_H_
#define TVM_RELAY_PASS_TYPE_VISITOR_H_
#include <vector>
#include "./type_functor.h"
namespace tvm {
namespace relay {
/*! \brief A type visitor for visitors which make use of internal
 * mutable state.
 *
 * We recursively visit each type contained inside the visited type.
 */
template <typename... Args>
struct TypeVisitor : ::tvm::relay::TypeFunctor<void(const Type& n, Args...)> {
  void VisitType_(const TypeParamNode* op, Args... args) override {}

  void VisitType_(const FuncTypeNode* op, Args... args) override {
    // NOTE: `args` is passed as-is rather than perfectly forwarded:
    // forwarding the same pack more than once could move from the
    // arguments on the first recursive call and hand moved-from values
    // to the remaining calls.
    for (auto type_param : op->type_params) {
      this->VisitType(type_param, args...);
    }

    for (auto type_cs : op->type_constraints) {
      this->VisitType(type_cs, args...);
    }

    for (auto arg_type : op->arg_types) {
      this->VisitType(arg_type, args...);
    }

    this->VisitType(op->ret_type, args...);
  }

  void VisitType_(const TensorTypeNode* op, Args... args) override {}

  void VisitType_(const TupleTypeNode* op, Args... args) override {
    // Same reasoning as above: do not forward the pack inside a loop.
    for (const Type& t : op->fields) {
      this->VisitType(t, args...);
    }
  }

  void VisitType_(const TypeRelationNode* op, Args... args) override {
    for (const Type& t : op->args) {
      this->VisitType(t, args...);
    }
  }

  void VisitType_(const IncompleteTypeNode* op, Args... args) override {}
};
// A functional visitor for rebuilding an AST in place.
// Subclasses override individual VisitType_ cases; every other type
// constructor is reconstructed from its (recursively visited) parts.
struct TypeMutator : TypeFunctor<Type(const Type& n)> {
  Type VisitType_(const TensorTypeNode* op) override {
    // TODO(@jroesch): maybe we should recursively visit
    return TensorTypeNode::make(op->shape, op->dtype);
  }

  Type VisitType_(const TypeParamNode* op) override {
    // Type parameters are leaves: return the node unchanged.
    return GetRef<TypeParam>(op);
  }

  Type VisitType_(const FuncTypeNode* op) override {
    // Rebuild every component of the function type.
    Array<TypeParam> type_params;
    for (auto type_param : op->type_params) {
      auto new_type_param = VisitType(type_param);
      if (const TypeParamNode* tin = new_type_param.as<TypeParamNode>()) {
        type_params.push_back(GetRef<TypeParam>(tin));
      } else {
        // A mutated type parameter must remain a type parameter.
        CHECK(false) << new_type_param << std::endl;
      }
    }

    Array<TypeConstraint> type_constraints;
    for (auto type_cs : op->type_constraints) {
      auto new_type_cs = VisitType(type_cs);
      // Use the NodeRef::as<T>() accessor for consistency with the
      // type-param case above (was previously a stray `As<...>` call).
      if (const TypeConstraintNode* tin =
              new_type_cs.as<TypeConstraintNode>()) {
        type_constraints.push_back(GetRef<TypeConstraint>(tin));
      } else {
        // A mutated constraint must remain a constraint.
        CHECK(false) << new_type_cs << std::endl;
      }
    }

    std::vector<Type> args;
    for (auto arg_type : op->arg_types) {
      args.push_back(VisitType(arg_type));
    }

    return FuncTypeNode::make(tvm::Array<Type>(args), VisitType(op->ret_type),
                              type_params, type_constraints);
  }

  Type VisitType_(const TupleTypeNode* op) override {
    std::vector<Type> new_fields;
    for (const Type& t : op->fields) {
      new_fields.push_back(this->VisitType(t));
    }
    return TupleTypeNode::make(new_fields);
  }

  Type VisitType_(const TypeRelationNode* type_rel) override {
    std::vector<Type> new_args;
    for (const Type& t : type_rel->args) {
      new_args.push_back(this->VisitType(t));
    }
    return TypeRelationNode::make(type_rel->name, type_rel->func_, new_args);
  }

  Type VisitType_(const IncompleteTypeNode* op) override {
    // Incomplete (inference) type variables are leaves.
    return GetRef<IncompleteType>(op);
  }
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_TYPE_VISITOR_H_
/*!
* Copyright (c) 2018 by Contributors
* \file include/tvm/relay/pass/unifier.h
* \brief The type unifier which solves a system of equations between
* incomplete types.
*/
#ifndef TVM_RELAY_PASS_UNIFIER_H_
#define TVM_RELAY_PASS_UNIFIER_H_
#include <tvm/relay/expr.h>
#include <string>
#include "./type_functor.h"
namespace tvm {
namespace relay {
/*! \brief Error raised by union-find operations (see unifier.cc). */
struct UnionFindError : dmlc::Error {
  explicit UnionFindError(const std::string& msg) : Error(msg) {}
};
/*! \brief Error raised when two types cannot be unified. */
struct UnificationError : dmlc::Error {
  explicit UnificationError(const std::string& msg) : Error(msg) {}
};
/*! \brief Error raised when substitution cannot produce a concrete type. */
struct SubstitutionError : dmlc::Error {
  explicit SubstitutionError(const std::string& msg) : Error(msg) {}
};
/*! \brief A union-find data structure for the type-checker */
class UnionFind;
class UnionFindNode : public Node {
 public:
  /*! \brief The internal map from incomplete types to their representatives. */
  tvm::Map<IncompleteType, Type> uf_map;
  UnionFindNode() {}
  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("uf_map", &uf_map); }
  TVM_DLL static UnionFind make(tvm::Map<IncompleteType, Type> uf_map);
  /*! \brief Insert it into the union find.
   * \param it The type to add to the union find.
   */
  void Insert(const IncompleteType& it);
  /*! \brief Union operation, combine two equivalence classes.
   * \param it The incomplete type to unify.
   * \param t The other type.
   */
  void Unify(const IncompleteType& it, const Type& t);
  /*! \brief Find operation, returns the representative of the argument.
   * \param it The element to lookup.
   */
  Type Find(const IncompleteType& it);
  /*! \brief Debugging helper; implementation in unifier.cc. */
  void debug();
  /*! \brief Assert that l and r are alpha-equal (presumably fatal
   * otherwise -- see unifier.cc); used for debugging. */
  void AssertAlphaEqual(const Type& l, const Type& r);
  static constexpr const char* _type_key = "relay.UnionFind";
  TVM_DECLARE_NODE_TYPE_INFO(UnionFindNode, Node);
};
/*! \brief Reference class for UnionFindNode; see the comment on
 * operator-> for why the standard ref macros are not used. */
class UnionFind : public NodeRef {
 public:
  UnionFind() {}
  explicit UnionFind(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
  // The union find structure is mutable so we do not use the standard macros
  // and expose the pointer via `->`.
  UnionFindNode* operator->() const {
    return static_cast<UnionFindNode*>(node_.get());
  }
  using ContainerType = UnionFindNode;
};
class TypeUnifier;
/*! \brief Solves type equations on top of the union-find structure.
 *
 * Inherits privately from TypeFunctor so the visitor entry points stay
 * an implementation detail rather than public interface.
 */
class TypeUnifierNode : public Node,
                        private TypeFunctor<Type(const Type&, const Type)> {
 public:
  /*! \brief Union-find over incomplete (inference) type variables. */
  UnionFind union_find;
  TypeUnifierNode() {}
  void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("union_find", &union_find); }
  TVM_DLL static TypeUnifier make(UnionFind uf);
  /*! \brief Introduces a new type var into the unifier */
  void Insert(const IncompleteType& v);
  /*! \brief Unifies two types if possible, throws a unification error if it
   * cannot */
  Type Unify(const Type& t1, const Type& t2);
  /*! \brief Attempts to substitute all type vars in t with concrete types,
   * throws substitution error if it cannot concretize*/
  Type Subst(const Type& t);
  // /*! \brief Checks the kinds in the given type */
  // Type CheckKinds(const Type& t);
  static constexpr const char* _type_key = "relay.TypeUnifier";
  TVM_DECLARE_NODE_TYPE_INFO(TypeUnifierNode, Node);
 private:
  /*! \brief Unify incomplete type with another type. */
  Type UnifyWithIncompleteType(const Type& t1, const IncompleteType tvn2);
  /*! \brief Implements unification between two types with incomplete portions.
   */
  Type VisitType(const Type& t1, const Type t2) override;
  // Visitor Cases
  Type VisitType_(const IncompleteTypeNode* t1, const Type t2) override;
  Type VisitType_(const TensorTypeNode* t1, const Type t2) override;
  Type VisitType_(const TypeParamNode* t1, const Type t2) override;
  Type VisitType_(const FuncTypeNode* t1, const Type t2) override;
  Type VisitType_(const TupleTypeNode* t1, const Type t2) override;
  Type VisitType_(const TypeRelationNode* s1, const Type t2) override;
};
/*! \brief Reference class for TypeUnifierNode; exposes a mutable node
 * pointer so the unifier can be updated in place by the type checker. */
class TypeUnifier : public NodeRef {
 public:
  TypeUnifier() {}
  explicit TypeUnifier(std::shared_ptr<tvm::Node> p) : NodeRef(p) {}
  // no const so that unifier can be mutable as a member of typechecker
  inline TypeUnifierNode* operator->() const {
    return static_cast<TypeUnifierNode*>(node_.get());
  }
  using ContainerType = TypeUnifierNode;
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_UNIFIER_H_
import numpy as np
from tvm.relay.expr import Let, Constant
from tvm.relay.ir_builder import IRBuilder
def test_let():
    """A let built through IRBuilder yields a Let node whose variable,
    value and body match what was constructed."""
    b = IRBuilder()
    x = b.let('x', 1)
    b.ret(x)
    prog, _ = b.get()
    assert isinstance(prog, Let)
    var = prog.var
    value = prog.value
    assert var.name_hint == 'x'
    assert var == prog.body
    assert isinstance(value, Constant)
    assert value.data.asnumpy() == np.array(1)
    # PEP 8: comparisons with None use `is`, not `==`.
    assert prog.value_type is None
if __name__ == "__main__":
    # Allow running this file directly without a test runner.
    test_let()
""" test ir"""
import tvm
from tvm import relay
from tvm.expr import *
# Span
def test_span():
    """Span round-trips its fields through the FFI."""
    span = relay.Span(None, 1, 1)
    # PEP 8: use `is None` for singleton comparison.
    assert span.source is None
    assert span.lineno == 1
    assert span.col_offset == 1
    assert span.same_as(span)
    assert span == span
    assert isinstance(span, relay.base.Span)
    str(span)  # smoke-test pretty printing
# Types
def test_tensor_type():
    """TensorType stores shape and dtype; span defaults to None."""
    shape = tvm.convert([1, 2, 3])
    dtype = 'float32'
    tt = relay.TensorType(shape, dtype)
    assert tt.dtype == dtype
    assert tt.shape == shape
    assert tt.span is None
    str(tt)
def test_type_param():
    """TypeParam exposes its kind (and, eventually, a span)."""
    tp = relay.TypeParam('name', relay.Kind.Shape)
    # Previously a bare expression with no effect -- assert it so the
    # comparison is actually checked.
    assert tp.kind == relay.Kind.Shape
    tp.span  # TODO allow us to set span
    str(tp)
def test_func_type():
    """FuncType stores arg types, return type, type params and constraints."""
    type_params = tvm.convert([])
    type_constraints = tvm.convert([])  # TODO: fill me in
    arg_types = tvm.convert([])
    ret_type = None
    tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
    assert tf.type_params == type_params
    assert tf.type_constraints == type_constraints
    assert tf.arg_types == arg_types
    assert tf.ret_type == ret_type
    assert tf.span is None
    # TODO make sure we can set
    str(tf)
def test_constant():
    """Constant wraps an NDArray verbatim."""
    arr = tvm.nd.array(10)
    const = relay.Constant(arr)
    assert const.data == arr
    assert const.span is None
    str(const)
def test_tuple():
    """Tuple stores its (possibly empty) field list."""
    fields = tvm.convert([])
    tup = relay.Tuple(fields)
    assert tup.fields == fields
    assert tup.span is None
    str(tup)
def test_local_var():
    """Var keeps its name hint."""
    name_hint = 's'
    lv = relay.Var(name_hint)
    # Previously a bare expression with no effect -- assert it.
    assert lv.name_hint == name_hint
    # assert lv.span == None todo(@jroesch): what do we do about spans
    str(lv)
def test_global_var():
    """GlobalVar keeps its name hint."""
    name_hint = 'g'
    gv = relay.GlobalVar(name_hint)
    # Previously a bare expression with no effect -- assert it.
    assert gv.name_hint == name_hint
    # assert lv.span == None todo(@jroesch): what do we do about spans
    str(gv)
def test_param():
    """Param binds a variable to an (optional) type annotation."""
    lv = relay.Var('x')
    ty = None
    param = relay.Param(lv, ty)
    assert param.var == lv
    assert param.type == ty
    assert param.span is None
    str(param)
def test_function():
    """Function stores params, body, return type and type params."""
    param_names = ['a', 'b', 'c', 'd']
    params = tvm.convert([relay.Param(relay.Var(n), None) for n in param_names])
    ret_type = None
    body = None
    type_params = tvm.convert([])
    fn = relay.Function(params, ret_type, body, type_params)
    assert fn.params == params
    assert fn.body == body
    assert fn.type_params == type_params
    assert fn.span is None
    str(fn)
def test_call():
    """Call stores the callee and the argument list."""
    op = relay.Var('f')
    arg_names = ['a', 'b', 'c', 'd']
    args = tvm.convert([relay.Var(n) for n in arg_names])
    call = relay.Call(op, args, None, None)
    assert call.op == op
    assert call.args == args
    assert call.span is None
    str(call)
def test_let():
    """Let stores variable, bound value, body and optional value type."""
    lv = relay.Var('x')
    ty = None
    arr = tvm.nd.array(10)
    value = relay.Constant(arr)
    # I would prefer that the order of arguments
    # matches syntax let x: t = v in b
    let = relay.Let(lv, value, lv, ty)
    assert let.var == lv
    assert let.value == value
    assert let.value_type == ty
    assert let.body == lv
    assert let.span is None
    str(let)
def test_if():
    """If stores the condition and both branches."""
    cond = relay.Var('cond')
    left = relay.Var('left')
    right = relay.Var('right')
    ife = relay.If(cond, left, right)
    assert ife.cond == cond
    assert ife.true_branch == left
    assert ife.false_branch == right
    assert ife.span is None
    str(ife)
if __name__ == "__main__":
    # Smoke-test every IR node constructor when executed directly.
    test_span()
    test_tensor_type()
    test_type_param()
    test_func_type()
    test_constant()
    test_tuple()
    test_local_var()
    test_global_var()
    test_param()
    test_function()
    test_call()
    test_let()
    test_if()
from tvm import relay
def test_op_attr():
    """Per-operator attributes: register "ftest" on exp, then query it back
    and check that "log" is unaffected."""
    log_op = relay.op.get("log")

    @relay.op.register("exp", "ftest")
    def compute(x):  # stored under attribute key "ftest" of op "exp"
        return x + 1

    assert log_op.num_inputs == 1
    assert log_op.get_attr("ftest") is None
    assert relay.op.get("exp").get_attr("ftest")(1) == 2
def test_op_level1():
    """Level-1 unary ops produce calls carrying the right op metadata."""
    x = relay.Var("x")
    for name in ("log", "exp", "sqrt"):
        call = getattr(relay, name)(x)
        assert call.op.name == name
        assert call.op.support_level == 1
        assert call.args[0] == x
if __name__ == "__main__":
    # Run the operator-registry smoke tests directly.
    test_op_attr()
    test_op_level1()
"""Test that type checker correcly computes types
for expressions.
"""
import tvm
import numpy as np
from tvm.relay.ir_pass import check_expr
from tvm.relay.ir_builder import IRBuilder, func_type
from tvm.relay.ir_builder import scalar_type, convert, tensor_type
from tvm.relay.env import Environment
from tvm.relay.op import log, add, equal, subtract, concat
from tvm.relay.expr import Function
def assert_has_type(expr, typ, env=None):
    """Type-check `expr` in `env` and assert its checked type is `typ`.

    A fresh Environment is created per call when none is given: the old
    default of `env=Environment({})` was evaluated once at definition
    time, so a single environment was shared (and potentially mutated)
    across every test that used the default.
    """
    if env is None:
        env = Environment({})
    checked_expr = check_expr(env, expr)
    assert checked_expr.checked_type() == typ
def assert_decl_has_type(env, name, typ):
    """Assert the global declaration `name` in `env` has checked type `typ`."""
    assert env[name].checked_type() == typ
def test_monomorphic_let():
    "Program: let x = 1; return x"
    ib = IRBuilder()
    x = ib.let('x', 1.0, value_type=scalar_type('float64'))
    ib.ret(x)
    prog, env = ib.get()
    assert_has_type(prog, scalar_type('float64'))
def test_single_op():
    "Program: fn (x : float32) { let t1 = f(x); t1 }"
    ib = IRBuilder()
    with ib.function(('x', 'float32')) as func:
        x, = func.param_ids()
        t1 = ib.let('t1', log(x))
        ib.ret(t1)
    assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
def test_add_op():
    """
    Program:
       fn (x, y) {
           return x + y;
       }
    """
    ib = IRBuilder()
    x = ib.param('x', tensor_type(5, 5, 5))
    y = ib.param('y', tensor_type(5, 5, 5))
    with ib.function(x, y) as func:
        ib.ret(add(x.var, y.var))
    ib.ret(func)
    prog, env = ib.get()
    ttype = tensor_type(5, 5, 5)
    expected_ty = func_type([ttype, ttype], ttype)
    assert_has_type(func.to_func(), expected_ty)
def test_add_broadcast_op():
    """
    Program:
       fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] {
           return x + y;
       }
    """
    b = IRBuilder()
    x = b.param('x', tensor_type(10, 4))
    y = b.param('y', tensor_type(5, 10, 1))
    with b.function(x, y) as func:
        b.ret(add(x.var, y.var))
    b.ret(func)
    prog, env = b.get()
    # The expected signature must match the program built above; the old
    # tensor_type(5, 5, 5) values were copy-pasted from test_add_op and
    # contradicted both the parameter types and the broadcast result.
    expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)],
                            tensor_type(5, 10, 4))
    assert_has_type(func.to_func(), expected_ty)
def test_dual_op():
    """Program:
       fn (x : Tensor[f32, (10, 10)]) {
         let t1 = log(x);
         let t2 = add(t1, x);
         return t2;
       }
    """
    b = IRBuilder()
    with b.function(('x', tensor_type(10, 10))) as func:
        x, = func.param_ids()
        t1 = b.let('t1', log(x))
        t2 = b.let('t2', add(t1, x))
        b.ret(t2)
    # NOTE(review): the asserted type reads/returns 'float32' scalars while
    # the function is built over tensor_type(10, 10) -- looks inconsistent;
    # confirm against ir_builder's func_type handling before enabling.
    assert_has_type(func.to_func(), func_type(['float32'], 'float32'))
def test_decl():
    """Program:
       def f(x : Tensor[f32, (10, 10)]) {
           let lx = log(x);
           return lx;
       }
    """
    ib = IRBuilder()
    x = ib.param('x')
    with ib.decl('f', x):
        lx = ib.let('lx', log(x))
        ib.ret(lx)
    _, env = ib.get()
    assert_decl_has_type(env, 'f', func_type(['float32'], 'float32'))
def test_recursion():
    """
    Program:
       def f(n: i32, data: f32) -> f32 {
          if (n == 0) {
              return f(n - 1, log(data));
          } else {
              return data;
          }
       }
       f(2, 10000);
    """
    b = IRBuilder()
    f = b.global_var('f')
    n = b.param('n', ty='int32')
    data = b.param('data', ty='float32')
    with b.decl(f, n, data):
        # NOTE(review): recursing in the n == 0 branch (and comparing an
        # int32 param against the float literal 0.0) looks inverted; this
        # only needs to type-check today, but confirm intent before the
        # test is executed by a real evaluator.
        with b.if_scope(equal(n, convert(0.0))):
            b.ret(f(subtract(n, convert(1)), log(data)))
        with b.else_scope():
            b.ret(data)
    b.ret(f(convert(2.0), convert(10000.0)))
    assert_decl_has_type(b.env, 'f', func_type(
        ['int32', 'float32'], 'float32'))
    # TODO(@jroesch): need evaluator or new runtime
    # to execute this.
def test_concat():
    """
    Program:
        def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
            return concat(x, y);
        }
    """
    builder = IRBuilder()
    try_concat2 = builder.global_var('try_concat2')
    x = builder.param('x', ty=tensor_type(3, 2))
    y = builder.param('y', ty=tensor_type(2, 2))
    with builder.decl(try_concat2, x, y):
        builder.ret(concat(x, y))
    fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
    assert_decl_has_type(builder.env, try_concat2, fn_ty)
if __name__ == "__main__":
    # Most cases are disabled until the type checker supports them;
    # only test_concat is currently exercised.
    # test_monomorphic_let()
    # test_single_op()
    # test_add_op()
    # test_add_broadcast_op()
    # test_dual_op()
    # test_decl()
    # test_recursion()
    test_concat()
......@@ -18,6 +18,8 @@ TVM_FFI=cython python -m nose -v tests/python/integration || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/integration || exit -1
TVM_FFI=cython python -m nose -v tests/python/contrib || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/contrib || exit -1
TVM_FFI=cython python -m nose -v tests/python/relay || exit -1
TVM_FFI=ctypes python3 -m nose -v tests/python/relay || exit -1
# Do not enable OpenGL
# TVM_FFI=cython python -m nose -v tests/webgl || exit -1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment