Commit d05fed22 by 雾雨魔理沙 Committed by ziheng

[Relay] Reference (#2489)

* move

fix test

fix lint

fix test

add more code

fix lint

better type infer ability

* fix build

* address comment
parent 895ef972
......@@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode {
TVM_DLL static TupleGetItem make(Expr tuple, int index);
static constexpr const char * _type_key = "relay.TupleGetItem";
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
/*! \brief Create a new Reference out of initial value. */
class RefCreate;
class RefCreateNode : public ExprNode {
public:
/*! \brief The initial value of the Reference. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefCreate make(Expr value);
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);
/*! \brief Get value out of Reference. */
class RefRead;
class RefReadNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefRead make(Expr ref);
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
/*! \brief The value to write into. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
/*!
* \brief Base class of the temporary expression.
*
......
......@@ -89,6 +89,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
......@@ -108,6 +111,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
return vtable;
}
};
......@@ -133,6 +139,9 @@ class ExprVisitor
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
virtual void VisitType(const Type& t);
protected:
......@@ -168,6 +177,9 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
......
......@@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode {
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
/*! \brief A reference value. */
class RefValue;
struct RefValueNode : ValueNode {
mutable Value value;
RefValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
}
TVM_DLL static RefValue make(Value val);
static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
} // namespace relay
} // namespace tvm
......
......@@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode {
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
/*!
* \brief The type of reference values.
*/
class RefType;
/*!
* \brief Reference Type in relay.
*/
class RefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;
RefTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
}
TVM_DLL static RefType make(Type value);
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);
class TypeReporter;
/*!
......
......@@ -44,6 +44,7 @@ FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
RefType = ty.RefType
# Expr
Expr = expr.Expr
......@@ -56,15 +57,18 @@ Call = expr.Call
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator
RefCreate = expr.RefCreate
RefRead = expr.RefRead
RefWrite = expr.RefWrite
# helper functions
var = expr.var
const = expr.const
bind = expr.bind
# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator
# Parser
fromtext = parser.fromtext
......@@ -283,6 +283,15 @@ class GraphRuntimeCodegen(ExprFunctor):
def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")
def visit_ref_create(self, _):
raise RuntimeError("reference not supported")
def visit_ref_read(self, _):
raise RuntimeError("reference not supported")
def visit_ref_write(self, _):
raise RuntimeError("reference not supported")
def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
......
......@@ -45,6 +45,7 @@ class TupleValue(Value):
def __iter__(self):
return iter(self.fields)
@register_relay_node
class Closure(Value):
"""A closure produced by the interpreter."""
......@@ -79,6 +80,13 @@ class TensorValue(Value):
return str(self.data)
@register_relay_node
class RefValue(Value):
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.RefValue, value)
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0)))
......
......@@ -327,6 +327,46 @@ class TupleGetItem(Expr):
_make.TupleGetItem, tuple_value, index)
@register_relay_node
class RefCreate(Expr):
"""Create a new reference from initial value.
Parameters
----------
value: tvm.relay.Expr
The initial value.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefCreate, value)
@register_relay_node
class RefRead(Expr):
"""Get the value inside the reference.
Parameters
----------
ref: tvm.relay.Expr
The reference.
"""
def __init__(self, ref):
self.__init_handle_by_constructor__(_make.RefRead, ref)
@register_relay_node
class RefWrite(Expr):
"""
Update the value inside the reference.
The whole expression will evaluate to an empty tuple.
Parameters
----------
ref: tvm.relay.Expr
The reference.
value: tvm.relay.Expr
The new value.
"""
def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)
class TempExpr(Expr):
"""Baseclass of all TempExpr.
......
......@@ -41,6 +41,12 @@ class ExprFunctor:
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
elif isinstance(expr, RefCreate):
res = self.visit_ref_create(expr)
elif isinstance(expr, RefRead):
res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite):
res = self.visit_ref_write(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))
......@@ -81,6 +87,14 @@ class ExprFunctor:
def visit_constant(self, _):
raise NotImplementedError()
def visit_ref_create(self, _):
raise NotImplementedError()
def visit_ref_write(self, _):
raise NotImplementedError()
def visit_ref_read(self, _):
raise NotImplementedError()
class ExprMutator(ExprFunctor):
"""
......@@ -145,8 +159,8 @@ class ExprMutator(ExprFunctor):
def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])
def visit_ref_new(self, r):
return RefNew(self.visit(r.value))
def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))
def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))
......
......@@ -210,6 +210,19 @@ class TypeRelation(TypeConstraint):
func, args, num_inputs, attrs)
@register_relay_node
class RefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value)
def scalar_type(dtype):
"""Creates a scalar type.
......
......@@ -75,6 +75,23 @@ TVM_REGISTER_API("relay._make.TensorValue")
*ret = TensorValueNode::make(data);
});
RefValue RefValueNode::make(Value value) {
NodePtr<RefValueNode> n = make_node<RefValueNode>();
n->value = value;
return RefValue(n);
}
TVM_REGISTER_API("relay._make.RefValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefValueNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node,
tvm::IRPrinter* p) {
p->stream << "RefValueNode(" << node->value << ")";
});
/*!
* \brief A stack frame in the Relay interpreter.
*
......@@ -432,6 +449,31 @@ class Interpreter :
}
}
Value VisitExpr_(const RefWriteNode* op) final {
Value r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
rv->value = Eval(op->value);
return TupleValueNode::make({});
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
}
}
Value VisitExpr_(const RefCreateNode* op) final {
return RefValueNode::make(Eval(op->value));
}
Value VisitExpr_(const RefReadNode* op) final {
Value r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
return rv->value;
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
}
}
InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) {
......
......@@ -207,6 +207,14 @@ class AlphaEqualHandler:
return false;
}
}
bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
return TypeEqual(lhs->value, rhs->value);
}
return false;
}
// Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) {
......@@ -361,6 +369,29 @@ class AlphaEqualHandler:
}
}
bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final {
if (const RefCreateNode* nr = e2.as<RefCreateNode>()) {
return ExprEqual(op->value, nr->value);
} else {
return false;
}
}
bool VisitExpr_(const RefReadNode* op, const Expr& e2) final {
if (const RefReadNode* r = e2.as<RefReadNode>()) {
return ExprEqual(op->ref, r->ref);
} else {
return false;
}
}
bool VisitExpr_(const RefWriteNode* op, const Expr& e2) final {
if (const RefWriteNode* r = e2.as<RefWriteNode>()) {
return ExprEqual(op->ref, r->ref) && ExprEqual(op->value, r->value);
} else {
return false;
}
}
private:
// whether to map open terms.
bool map_free_var_{false};
......
......@@ -271,13 +271,59 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
RefCreate RefCreateNode::make(Expr value) {
NodePtr<RefCreateNode> n = make_node<RefCreateNode>();
n->value = std::move(value);
return RefCreate(n);
}
TVM_REGISTER_API("relay._expr.TempExprRealize")
TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefCreateNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
p->stream << "RefCreateNode(" << node->value << ")";
});
RefRead RefReadNode::make(Expr ref) {
NodePtr<RefReadNode> n = make_node<RefReadNode>();
n->ref = std::move(ref);
return RefRead(n);
}
TVM_REGISTER_API("relay._make.RefRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TempExpr temp = args[0];
*ret = temp->Realize();
*ret = RefReadNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
p->stream << "RefReadNode(" << node->ref << ")";
});
RefWrite RefWriteNode::make(Expr ref, Expr value) {
NodePtr<RefWriteNode> n = make_node<RefWriteNode>();
n->ref = std::move(ref);
n->value = std::move(value);
return RefWrite(n);
}
TVM_REGISTER_API("relay._make.RefWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefWriteNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TempExpr temp = args[0];
*ret = temp->Realize();
});
} // namespace relay
} // namespace tvm
......@@ -157,6 +157,34 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
}
}
Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return RefCreateNode::make(value);
}
}
Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
Expr ref = this->Mutate(op->ref);
if (ref.same_as(op->ref)) {
return GetRef<Expr>(op);
} else {
return RefReadNode::make(ref);
}
}
Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
Expr ref = this->Mutate(op->ref);
Expr value = this->Mutate(op->value);
if (ref.same_as(op->ref) && value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return RefWriteNode::make(ref, value);
}
}
Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::VisitExpr(const Expr& expr) {
......@@ -226,6 +254,19 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
this->VisitExpr(op->tuple);
}
void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
this->VisitExpr(op->value);
}
void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
this->VisitExpr(op->ref);
}
void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
this->VisitExpr(op->ref);
this->VisitExpr(op->value);
}
void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply
......
......@@ -16,9 +16,9 @@ namespace relay {
// Hash handler for Relay.
class RelayHashHandler:
public AttrsHashHandler,
public TypeFunctor<size_t(const Type&)>,
public ExprFunctor<size_t(const Expr&)> {
public AttrsHashHandler,
public TypeFunctor<size_t(const Type&)>,
public ExprFunctor<size_t(const Expr&)> {
public:
explicit RelayHashHandler() {}
......@@ -175,6 +175,12 @@ class RelayHashHandler:
return hash;
}
size_t VisitType_(const RefTypeNode* rtn) final {
size_t hash = std::hash<std::string>()(RefTypeNode::_type_key);
hash = Combine(hash, TypeHash(rtn->value));
return hash;
}
// Expr hashing.
size_t NDArrayHash(const runtime::NDArray& array) {
size_t hash = std::hash<uint8_t>()(array->dtype.code);
......@@ -280,6 +286,24 @@ class RelayHashHandler:
return hash;
}
size_t VisitExpr_(const RefCreateNode* rn) final {
size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
size_t VisitExpr_(const RefReadNode* rn) final {
size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
hash = Combine(hash, ExprHash(rn->ref));
return hash;
}
size_t VisitExpr_(const RefWriteNode* rn) final {
size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
hash = Combine(hash, ExprHash(rn->ref));
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
private:
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_;
......
......@@ -363,6 +363,34 @@ class TextPrinter :
return id;
}
TextValue VisitExpr_(const RefCreateNode* op) final {
TextValue value = GetValue(op->value);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << "RefCreate(" << op->value << ")";
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const RefReadNode* op) final {
TextValue ref = GetValue(op->ref);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << "RefRead(" << ref << ")";
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const RefWriteNode* op) final {
TextValue ref = GetValue(op->ref);
TextValue value = GetValue(op->value);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")";
this->PrintEndInst("\n");
return id;
}
/*!
* \brief Print the type to os
* \param type The type to be printed.
......@@ -405,6 +433,10 @@ class TextPrinter :
os << "]";
}
void VisitType_(const RefTypeNode* node, std::ostream& os) final {
VisitTypeDefault_(node, os);
}
void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*)
// by default always print as meta-data
os << meta_.GetMetaNode(GetRef<NodeRef>(node));
......
......@@ -164,5 +164,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TupleTypeNode(" << node->fields << ")";
});
RefType RefTypeNode::make(Type value) {
NodePtr<RefTypeNode> n = make_node<RefTypeNode>();
n->value = std::move(value);
return RefType(n);
}
TVM_REGISTER_API("relay._make.RefType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefTypeNode::make(args[0]);
});
TVM_REGISTER_NODE_TYPE(RefTypeNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefTypeNode>([](const RefTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "RefTypeNode(" << node->value << ")";
});
} // namespace relay
} // namespace tvm
......@@ -38,6 +38,10 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) {
}
}
void TypeVisitor::VisitType_(const RefTypeNode* op) {
this->VisitType(op->value);
}
void TypeVisitor::VisitType_(const TypeRelationNode* op) {
for (const Type& t : op->args) {
this->VisitType(t);
......@@ -119,6 +123,10 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) {
}
}
Type TypeMutator::VisitType_(const RefTypeNode* op) {
return RefTypeNode::make(this->VisitType(op->value));
}
Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
Array<Type> new_args = MutateArray(type_rel->args);
if (new_args.same_as(type_rel->args)) {
......
......@@ -68,7 +68,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key();
throw; // unreachable, written to stop compiler warning
......@@ -86,6 +86,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
return vtable;
}
};
......@@ -101,6 +102,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override;
void VisitType_(const RefTypeNode* op) override;
};
// Mutator that transform a type to another one.
......@@ -112,6 +114,7 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const RefTypeNode* op) override;
private:
Array<Type> MutateArray(Array<Type> arr);
......
......@@ -162,6 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
current->extern_ref = true;
}
}
void AddNode(const tvm::Node* key) {
auto it = graph_.node_map.find(key);
CHECK(it != graph_.node_map.end())
......@@ -174,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}
// Post order tree
void VisitExpr_(const FunctionNode* op) {
void VisitExpr_(const FunctionNode* op) final {
for (auto param : op->params) {
this->Update(param, nullptr, kOpaque);
}
......@@ -182,7 +183,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
ExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const ConstantNode* op) {
void VisitExpr_(const ConstantNode* op) final {
this->AddNode(op);
Node* node = graph_.node_map.at(op);
DataType dtype = TVMType2Type(op->data->dtype);
......@@ -202,7 +203,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
}
}
void VisitExpr_(const CallNode* call) {
void VisitExpr_(const CallNode* call) final {
CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call);
static auto fpattern =
......@@ -232,7 +233,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(call);
}
void VisitExpr_(const TupleNode* op) {
void VisitExpr_(const TupleNode* op) final {
CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective;
......@@ -247,7 +248,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op);
}
void VisitExpr_(const TupleGetItemNode* op) {
void VisitExpr_(const TupleGetItemNode* op) final {
CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op);
this->Update(op->tuple, node, kOpaque);
......@@ -255,11 +256,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op);
}
void VisitExpr_(const VarNode* op) {
void VisitExpr_(const VarNode* op) final {
this->AddNode(op);
}
void VisitExpr_(const LetNode* op) {
void VisitExpr_(const LetNode* op) final {
// do not fuse through let.
this->Update(op->var, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
......@@ -268,7 +269,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op);
}
void VisitExpr_(const IfNode* op) {
void VisitExpr_(const IfNode* op) final {
// do not fuse through if.
this->Update(op->cond, nullptr, kOpaque);
this->Update(op->true_branch, nullptr, kOpaque);
......@@ -276,6 +277,25 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const RefCreateNode* op) final {
this->Update(op->value, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const RefReadNode* op) final {
this->Update(op->ref, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const RefWriteNode* op) final {
this->Update(op->ref, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
};
IndexedForwardGraph IndexedForwardGraph::Create(
......
......@@ -82,6 +82,12 @@ struct KindChecker : TypeVisitor {
valid = valid && IsTypeKind(op->ret_type);
}
void VisitType_(const RefTypeNode* op) override {
// tuples should only contain normal types
this->VisitType(op->value);
valid = valid && IsTypeKind(op->value);
}
void VisitType_(const TypeRelationNode* op) override {
// arguments to type relation should be normal types
for (const Type& t : op->args) {
......
......@@ -431,6 +431,23 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
return solver_.Resolve(ret);
}
Type VisitExpr_(const RefCreateNode* op) final {
return RefTypeNode::make(GetType(op->value));
}
Type VisitExpr_(const RefReadNode* op) final {
Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefRead>(op));
return it;
}
Type VisitExpr_(const RefWriteNode* op) final {
Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
return TupleTypeNode::make({});
}
};
class TypeInferencer::Resolver : public ExprMutator {
......@@ -480,6 +497,18 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}
Expr VisitExpr_(const RefCreateNode* op) final {
return AttachCheckedType(op);
}
Expr VisitExpr_(const RefReadNode* op) final {
return AttachCheckedType(op);
}
Expr VisitExpr_(const RefWriteNode* op) final {
return AttachCheckedType(op);
}
// attach checked type to the mutated node.
template<typename T>
Expr AttachCheckedType(const T* op) {
......
......@@ -116,7 +116,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
}
// default: unify only if alpha-equal
Type VisitTypeDefault_(const Node* op, const Type& tn) override {
Type VisitTypeDefault_(const Node* op, const Type& tn) final {
NodeRef nr = GetRef<NodeRef>(op);
Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
if (!AlphaEqual(t1, tn)) {
......@@ -125,7 +125,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return t1;
}
Type VisitType_(const TupleTypeNode* op, const Type& tn) override {
Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
const auto* ttn = tn.as<TupleTypeNode>();
if (!ttn || op->fields.size() != ttn->fields.size()) {
return Type(nullptr);
......@@ -142,7 +142,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return TupleTypeNode::make(new_fields);
}
Type VisitType_(const FuncTypeNode* op, const Type& tn) override {
Type VisitType_(const FuncTypeNode* op, const Type& tn) final {
const auto* ftn = tn.as<FuncTypeNode>();
if (!ftn
|| op->arg_types.size() != ftn->arg_types.size()
......@@ -181,6 +181,14 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
}
Type VisitType_(const RefTypeNode* op, const Type& tn) final {
const auto* rtn = tn.as<RefTypeNode>();
if (!rtn) {
return Type(nullptr);
}
return RefTypeNode::make(Unify(op->value, rtn->value));
}
private:
TypeSolver* solver_;
};
......
......@@ -110,6 +110,22 @@ def test_loop():
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
def test_ref():
mod = relay.Module()
three_with_ref = relay.GlobalVar('three_with_ref')
i = relay.Var('i')
iv = relay.Var('iv')
u = relay.Var('u')
uv = relay.Var('uv')
body = relay.add(iv, uv)
body = relay.Let(uv, relay.RefRead(i), body)
body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
mod[three_with_ref] = relay.Function([], body)
check_eval(three_with_ref, [], 3, mod=mod)
def test_binds():
x = relay.var("x")
y = relay.add(x, x)
......@@ -118,6 +134,7 @@ def test_binds():
res = intrp.evaluate(y, binds={x: xx}).asnumpy()
tvm.testing.assert_allclose(xx + xx, res)
def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
......@@ -131,6 +148,7 @@ def test_kwargs_params():
res = intrp.evaluate(f)(x_data, **params).data
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
if __name__ == "__main__":
test_id()
test_add_const()
......@@ -140,3 +158,4 @@ if __name__ == "__main__":
test_loop()
test_binds()
test_kwargs_params()
test_ref()
......@@ -131,6 +131,18 @@ def test_tuple():
relay.TupleType([tp, tp]))
def test_ref():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
r = relay.RefCreate(x)
st = relay.scalar_type("float32")
assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st)
g = relay.RefRead(r)
assert relay.ir_pass.infer_type(g).checked_type == st
w = relay.RefWrite(r, y)
assert relay.ir_pass.infer_type(w).checked_type == relay.TupleType([])
def test_free_expr():
x = relay.var("x", "float32")
y = relay.add(x, x)
......@@ -187,12 +199,9 @@ if __name__ == "__main__":
test_decl()
test_recursion()
test_tuple()
test_generalized_tuple()
test_incomplete_call()
test_generalized_call()
test_call_with_type_args()
test_free_expr()
test_type_args()
test_self_reference()
test_global_var_recursion()
test_equal()
test_ref()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment