Commit d05fed22 by 雾雨魔理沙 Committed by ziheng

[Relay] Reference (#2489)

* move

fix test

fix lint

fix test

add more code

fix lint

better type infer ability

* fix build

* address comment
parent 895ef972
...@@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode { ...@@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode {
TVM_DLL static TupleGetItem make(Expr tuple, int index); TVM_DLL static TupleGetItem make(Expr tuple, int index);
static constexpr const char * _type_key = "relay.TupleGetItem"; static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
}; };
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
/*! \brief Create a new Reference out of initial value. */
class RefCreate;
class RefCreateNode : public ExprNode {
public:
/*! \brief The initial value of the Reference. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefCreate make(Expr value);
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);
/*! \brief Get value out of Reference. */
class RefRead;
class RefReadNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefRead make(Expr ref);
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
/*! \brief The value to write into. */
Expr value;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefWrite make(Expr ref, Expr value);
static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};
RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
/*! /*!
* \brief Base class of the temporary expression. * \brief Base class of the temporary expression.
* *
......
...@@ -89,6 +89,9 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -89,6 +89,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op, virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT; Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) { virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key()); throw Error(std::string("Do not have a default for ") + op->type_key());
} }
...@@ -108,6 +111,9 @@ class ExprFunctor<R(const Expr& n, Args...)> { ...@@ -108,6 +111,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
return vtable; return vtable;
} }
}; };
...@@ -133,6 +139,9 @@ class ExprVisitor ...@@ -133,6 +139,9 @@ class ExprVisitor
void VisitExpr_(const IfNode* op) override; void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override; void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
virtual void VisitType(const Type& t); virtual void VisitType(const Type& t);
protected: protected:
...@@ -168,6 +177,9 @@ class ExprMutator ...@@ -168,6 +177,9 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
/*! /*!
* \brief Used to visit the types inside of expressions. * \brief Used to visit the types inside of expressions.
* *
......
...@@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode { ...@@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode {
RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value); RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);
/*! \brief A reference value. */
class RefValue;
struct RefValueNode : ValueNode {
mutable Value value;
RefValueNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
}
TVM_DLL static RefValue make(Value val);
static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
};
RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode { ...@@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode {
RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type); RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);
/*!
* \brief The type of reference values.
*/
class RefType;
/*!
* \brief Reference Type in relay.
*/
class RefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;
RefTypeNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
}
TVM_DLL static RefType make(Type value);
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
};
RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);
class TypeReporter; class TypeReporter;
/*! /*!
......
...@@ -44,6 +44,7 @@ FuncType = ty.FuncType ...@@ -44,6 +44,7 @@ FuncType = ty.FuncType
TypeRelation = ty.TypeRelation TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type scalar_type = ty.scalar_type
RefType = ty.RefType
# Expr # Expr
Expr = expr.Expr Expr = expr.Expr
...@@ -56,15 +57,18 @@ Call = expr.Call ...@@ -56,15 +57,18 @@ Call = expr.Call
Let = expr.Let Let = expr.Let
If = expr.If If = expr.If
TupleGetItem = expr.TupleGetItem TupleGetItem = expr.TupleGetItem
RefCreate = expr.RefCreate
# ExprFunctor RefRead = expr.RefRead
ExprFunctor = expr_functor.ExprFunctor RefWrite = expr.RefWrite
ExprMutator = expr_functor.ExprMutator
# helper functions # helper functions
var = expr.var var = expr.var
const = expr.const const = expr.const
bind = expr.bind bind = expr.bind
# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator
# Parser # Parser
fromtext = parser.fromtext fromtext = parser.fromtext
...@@ -283,6 +283,15 @@ class GraphRuntimeCodegen(ExprFunctor): ...@@ -283,6 +283,15 @@ class GraphRuntimeCodegen(ExprFunctor):
def visit_op(self, _): def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form") raise Exception("can not compile op in non-eta expanded form")
def visit_ref_create(self, _):
raise RuntimeError("reference not supported")
def visit_ref_read(self, _):
raise RuntimeError("reference not supported")
def visit_ref_write(self, _):
raise RuntimeError("reference not supported")
def _get_json(self): def _get_json(self):
""" """
Convert the sequence of nodes stored by the compiler into the Convert the sequence of nodes stored by the compiler into the
......
...@@ -45,6 +45,7 @@ class TupleValue(Value): ...@@ -45,6 +45,7 @@ class TupleValue(Value):
def __iter__(self): def __iter__(self):
return iter(self.fields) return iter(self.fields)
@register_relay_node @register_relay_node
class Closure(Value): class Closure(Value):
"""A closure produced by the interpreter.""" """A closure produced by the interpreter."""
...@@ -79,6 +80,13 @@ class TensorValue(Value): ...@@ -79,6 +80,13 @@ class TensorValue(Value):
return str(self.data) return str(self.data)
@register_relay_node
class RefValue(Value):
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.RefValue, value)
def _arg_to_ast(arg): def _arg_to_ast(arg):
if isinstance(arg, TensorValue): if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0))) return Constant(arg.data.copyto(_nd.cpu(0)))
......
...@@ -327,6 +327,46 @@ class TupleGetItem(Expr): ...@@ -327,6 +327,46 @@ class TupleGetItem(Expr):
_make.TupleGetItem, tuple_value, index) _make.TupleGetItem, tuple_value, index)
@register_relay_node
class RefCreate(Expr):
"""Create a new reference from initial value.
Parameters
----------
value: tvm.relay.Expr
The initial value.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefCreate, value)
@register_relay_node
class RefRead(Expr):
"""Get the value inside the reference.
Parameters
----------
ref: tvm.relay.Expr
The reference.
"""
def __init__(self, ref):
self.__init_handle_by_constructor__(_make.RefRead, ref)
@register_relay_node
class RefWrite(Expr):
"""
Update the value inside the reference.
The whole expression will evaluate to an empty tuple.
Parameters
----------
ref: tvm.relay.Expr
The reference.
value: tvm.relay.Expr
The new value.
"""
def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)
class TempExpr(Expr): class TempExpr(Expr):
"""Baseclass of all TempExpr. """Baseclass of all TempExpr.
......
...@@ -41,6 +41,12 @@ class ExprFunctor: ...@@ -41,6 +41,12 @@ class ExprFunctor:
res = self.visit_constant(expr) res = self.visit_constant(expr)
elif isinstance(expr, Op): elif isinstance(expr, Op):
res = self.visit_op(expr) res = self.visit_op(expr)
elif isinstance(expr, RefCreate):
res = self.visit_ref_create(expr)
elif isinstance(expr, RefRead):
res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite):
res = self.visit_ref_write(expr)
else: else:
raise Exception("warning unhandled case: {0}".format(type(expr))) raise Exception("warning unhandled case: {0}".format(type(expr)))
...@@ -81,6 +87,14 @@ class ExprFunctor: ...@@ -81,6 +87,14 @@ class ExprFunctor:
def visit_constant(self, _): def visit_constant(self, _):
raise NotImplementedError() raise NotImplementedError()
def visit_ref_create(self, _):
raise NotImplementedError()
def visit_ref_write(self, _):
raise NotImplementedError()
def visit_ref_read(self, _):
raise NotImplementedError()
class ExprMutator(ExprFunctor): class ExprMutator(ExprFunctor):
""" """
...@@ -145,8 +159,8 @@ class ExprMutator(ExprFunctor): ...@@ -145,8 +159,8 @@ class ExprMutator(ExprFunctor):
def visit_match(self, m): def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern]) return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])
def visit_ref_new(self, r): def visit_ref_create(self, r):
return RefNew(self.visit(r.value)) return RefCreate(self.visit(r.value))
def visit_ref_write(self, r): def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value)) return RefWrite(self.visit(r.ref), self.visit(r.value))
......
...@@ -210,6 +210,19 @@ class TypeRelation(TypeConstraint): ...@@ -210,6 +210,19 @@ class TypeRelation(TypeConstraint):
func, args, num_inputs, attrs) func, args, num_inputs, attrs)
@register_relay_node
class RefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value)
def scalar_type(dtype): def scalar_type(dtype):
"""Creates a scalar type. """Creates a scalar type.
......
...@@ -75,6 +75,23 @@ TVM_REGISTER_API("relay._make.TensorValue") ...@@ -75,6 +75,23 @@ TVM_REGISTER_API("relay._make.TensorValue")
*ret = TensorValueNode::make(data); *ret = TensorValueNode::make(data);
}); });
RefValue RefValueNode::make(Value value) {
NodePtr<RefValueNode> n = make_node<RefValueNode>();
n->value = value;
return RefValue(n);
}
TVM_REGISTER_API("relay._make.RefValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefValueNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node,
tvm::IRPrinter* p) {
p->stream << "RefValueNode(" << node->value << ")";
});
/*! /*!
* \brief A stack frame in the Relay interpreter. * \brief A stack frame in the Relay interpreter.
* *
...@@ -432,6 +449,31 @@ class Interpreter : ...@@ -432,6 +449,31 @@ class Interpreter :
} }
} }
Value VisitExpr_(const RefWriteNode* op) final {
Value r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
rv->value = Eval(op->value);
return TupleValueNode::make({});
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
}
}
Value VisitExpr_(const RefCreateNode* op) final {
return RefValueNode::make(Eval(op->value));
}
Value VisitExpr_(const RefReadNode* op) final {
Value r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
return rv->value;
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
}
}
InterpreterState get_state(Expr e = Expr()) const { InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack; InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) { for (auto fr : this->stack_.frames) {
......
...@@ -207,6 +207,14 @@ class AlphaEqualHandler: ...@@ -207,6 +207,14 @@ class AlphaEqualHandler:
return false; return false;
} }
} }
bool VisitType_(const RefTypeNode* lhs, const Type& other) final {
if (const RefTypeNode* rhs = other.as<RefTypeNode>()) {
return TypeEqual(lhs->value, rhs->value);
}
return false;
}
// Expr equal checking. // Expr equal checking.
bool NDArrayEqual(const runtime::NDArray& lhs, bool NDArrayEqual(const runtime::NDArray& lhs,
const runtime::NDArray& rhs) { const runtime::NDArray& rhs) {
...@@ -361,6 +369,29 @@ class AlphaEqualHandler: ...@@ -361,6 +369,29 @@ class AlphaEqualHandler:
} }
} }
bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final {
if (const RefCreateNode* nr = e2.as<RefCreateNode>()) {
return ExprEqual(op->value, nr->value);
} else {
return false;
}
}
bool VisitExpr_(const RefReadNode* op, const Expr& e2) final {
if (const RefReadNode* r = e2.as<RefReadNode>()) {
return ExprEqual(op->ref, r->ref);
} else {
return false;
}
}
bool VisitExpr_(const RefWriteNode* op, const Expr& e2) final {
if (const RefWriteNode* r = e2.as<RefWriteNode>()) {
return ExprEqual(op->ref, r->ref) && ExprEqual(op->value, r->value);
} else {
return false;
}
}
private: private:
// whether to map open terms. // whether to map open terms.
bool map_free_var_{false}; bool map_free_var_{false};
......
...@@ -271,6 +271,53 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -271,6 +271,53 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
}); });
RefCreate RefCreateNode::make(Expr value) {
NodePtr<RefCreateNode> n = make_node<RefCreateNode>();
n->value = std::move(value);
return RefCreate(n);
}
TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefCreateNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
p->stream << "RefCreateNode(" << node->value << ")";
});
RefRead RefReadNode::make(Expr ref) {
NodePtr<RefReadNode> n = make_node<RefReadNode>();
n->ref = std::move(ref);
return RefRead(n);
}
TVM_REGISTER_API("relay._make.RefRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefReadNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
p->stream << "RefReadNode(" << node->ref << ")";
});
RefWrite RefWriteNode::make(Expr ref, Expr value) {
NodePtr<RefWriteNode> n = make_node<RefWriteNode>();
n->ref = std::move(ref);
n->value = std::move(value);
return RefWrite(n);
}
TVM_REGISTER_API("relay._make.RefWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefWriteNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")";
});
TVM_REGISTER_API("relay._expr.TempExprRealize") TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -278,6 +325,5 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") ...@@ -278,6 +325,5 @@ TVM_REGISTER_API("relay._expr.TempExprRealize")
*ret = temp->Realize(); *ret = temp->Realize();
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -157,6 +157,34 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { ...@@ -157,6 +157,34 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
} }
} }
Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return RefCreateNode::make(value);
}
}
Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
Expr ref = this->Mutate(op->ref);
if (ref.same_as(op->ref)) {
return GetRef<Expr>(op);
} else {
return RefReadNode::make(ref);
}
}
Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
Expr ref = this->Mutate(op->ref);
Expr value = this->Mutate(op->value);
if (ref.same_as(op->ref) && value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return RefWriteNode::make(ref, value);
}
}
Type ExprMutator::VisitType(const Type& t) { return t; } Type ExprMutator::VisitType(const Type& t) { return t; }
void ExprVisitor::VisitExpr(const Expr& expr) { void ExprVisitor::VisitExpr(const Expr& expr) {
...@@ -226,6 +254,19 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { ...@@ -226,6 +254,19 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
this->VisitExpr(op->tuple); this->VisitExpr(op->tuple);
} }
void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
this->VisitExpr(op->value);
}
void ExprVisitor::ExprVisitor::VisitExpr_(const RefReadNode* op) {
this->VisitExpr(op->ref);
}
void ExprVisitor::ExprVisitor::VisitExpr_(const RefWriteNode* op) {
this->VisitExpr(op->ref);
this->VisitExpr(op->value);
}
void ExprVisitor::VisitType(const Type& t) { return; } void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply // visitor to implement apply
......
...@@ -175,6 +175,12 @@ class RelayHashHandler: ...@@ -175,6 +175,12 @@ class RelayHashHandler:
return hash; return hash;
} }
size_t VisitType_(const RefTypeNode* rtn) final {
size_t hash = std::hash<std::string>()(RefTypeNode::_type_key);
hash = Combine(hash, TypeHash(rtn->value));
return hash;
}
// Expr hashing. // Expr hashing.
size_t NDArrayHash(const runtime::NDArray& array) { size_t NDArrayHash(const runtime::NDArray& array) {
size_t hash = std::hash<uint8_t>()(array->dtype.code); size_t hash = std::hash<uint8_t>()(array->dtype.code);
...@@ -280,6 +286,24 @@ class RelayHashHandler: ...@@ -280,6 +286,24 @@ class RelayHashHandler:
return hash; return hash;
} }
size_t VisitExpr_(const RefCreateNode* rn) final {
size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
size_t VisitExpr_(const RefReadNode* rn) final {
size_t hash = std::hash<std::string>()(RefReadNode::_type_key);
hash = Combine(hash, ExprHash(rn->ref));
return hash;
}
size_t VisitExpr_(const RefWriteNode* rn) final {
size_t hash = std::hash<std::string>()(RefWriteNode::_type_key);
hash = Combine(hash, ExprHash(rn->ref));
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
private: private:
// renaming of NodeRef to indicate two nodes equals to each other // renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_; std::unordered_map<NodeRef, size_t, NodeHash, NodeEqual> hash_map_;
......
...@@ -363,6 +363,34 @@ class TextPrinter : ...@@ -363,6 +363,34 @@ class TextPrinter :
return id; return id;
} }
TextValue VisitExpr_(const RefCreateNode* op) final {
TextValue value = GetValue(op->value);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << "RefCreate(" << op->value << ")";
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const RefReadNode* op) final {
TextValue ref = GetValue(op->ref);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << "RefRead(" << ref << ")";
this->PrintEndInst("\n");
return id;
}
TextValue VisitExpr_(const RefWriteNode* op) final {
TextValue ref = GetValue(op->ref);
TextValue value = GetValue(op->value);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << "RefWrite(" << ref << ", " << value << ")";
this->PrintEndInst("\n");
return id;
}
/*! /*!
* \brief Print the type to os * \brief Print the type to os
* \param type The type to be printed. * \param type The type to be printed.
...@@ -405,6 +433,10 @@ class TextPrinter : ...@@ -405,6 +433,10 @@ class TextPrinter :
os << "]"; os << "]";
} }
void VisitType_(const RefTypeNode* node, std::ostream& os) final {
VisitTypeDefault_(node, os);
}
void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*) void VisitTypeDefault_(const Node* node, std::ostream& os) final { // NOLINT(*)
// by default always print as meta-data // by default always print as meta-data
os << meta_.GetMetaNode(GetRef<NodeRef>(node)); os << meta_.GetMetaNode(GetRef<NodeRef>(node));
......
...@@ -164,5 +164,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -164,5 +164,24 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TupleTypeNode(" << node->fields << ")"; p->stream << "TupleTypeNode(" << node->fields << ")";
}); });
RefType RefTypeNode::make(Type value) {
NodePtr<RefTypeNode> n = make_node<RefTypeNode>();
n->value = std::move(value);
return RefType(n);
}
TVM_REGISTER_API("relay._make.RefType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefTypeNode::make(args[0]);
});
TVM_REGISTER_NODE_TYPE(RefTypeNode);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefTypeNode>([](const RefTypeNode* node,
tvm::IRPrinter* p) {
p->stream << "RefTypeNode(" << node->value << ")";
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -38,6 +38,10 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) { ...@@ -38,6 +38,10 @@ void TypeVisitor::VisitType_(const TupleTypeNode* op) {
} }
} }
void TypeVisitor::VisitType_(const RefTypeNode* op) {
this->VisitType(op->value);
}
void TypeVisitor::VisitType_(const TypeRelationNode* op) { void TypeVisitor::VisitType_(const TypeRelationNode* op) {
for (const Type& t : op->args) { for (const Type& t : op->args) {
this->VisitType(t); this->VisitType(t);
...@@ -119,6 +123,10 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) { ...@@ -119,6 +123,10 @@ Type TypeMutator::VisitType_(const TupleTypeNode* op) {
} }
} }
Type TypeMutator::VisitType_(const RefTypeNode* op) {
return RefTypeNode::make(this->VisitType(op->value));
}
Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) { Type TypeMutator::VisitType_(const TypeRelationNode* type_rel) {
Array<Type> new_args = MutateArray(type_rel->args); Array<Type> new_args = MutateArray(type_rel->args);
if (new_args.same_as(type_rel->args)) { if (new_args.same_as(type_rel->args)) {
......
...@@ -68,7 +68,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -68,7 +68,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeRelationNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TupleTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const IncompleteTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const RefTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Node* op, Args...) { virtual R VisitTypeDefault_(const Node* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->type_key(); LOG(FATAL) << "Do not have a default for " << op->type_key();
throw; // unreachable, written to stop compiler warning throw; // unreachable, written to stop compiler warning
...@@ -86,6 +86,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -86,6 +86,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode); RELAY_TYPE_FUNCTOR_DISPATCH(TypeRelationNode);
RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode); RELAY_TYPE_FUNCTOR_DISPATCH(IncompleteTypeNode);
RELAY_TYPE_FUNCTOR_DISPATCH(RefTypeNode);
return vtable; return vtable;
} }
}; };
...@@ -101,6 +102,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> { ...@@ -101,6 +102,7 @@ class TypeVisitor : public TypeFunctor<void(const Type& n)> {
void VisitType_(const FuncTypeNode* op) override; void VisitType_(const FuncTypeNode* op) override;
void VisitType_(const TupleTypeNode* op) override; void VisitType_(const TupleTypeNode* op) override;
void VisitType_(const TypeRelationNode* op) override; void VisitType_(const TypeRelationNode* op) override;
void VisitType_(const RefTypeNode* op) override;
}; };
// Mutator that transform a type to another one. // Mutator that transform a type to another one.
...@@ -112,6 +114,7 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> { ...@@ -112,6 +114,7 @@ class TypeMutator : public TypeFunctor<Type(const Type& n)> {
Type VisitType_(const FuncTypeNode* op) override; Type VisitType_(const FuncTypeNode* op) override;
Type VisitType_(const TupleTypeNode* op) override; Type VisitType_(const TupleTypeNode* op) override;
Type VisitType_(const TypeRelationNode* type_rel) override; Type VisitType_(const TypeRelationNode* type_rel) override;
Type VisitType_(const RefTypeNode* op) override;
private: private:
Array<Type> MutateArray(Array<Type> arr); Array<Type> MutateArray(Array<Type> arr);
......
...@@ -162,6 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -162,6 +162,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
current->extern_ref = true; current->extern_ref = true;
} }
} }
void AddNode(const tvm::Node* key) { void AddNode(const tvm::Node* key) {
auto it = graph_.node_map.find(key); auto it = graph_.node_map.find(key);
CHECK(it != graph_.node_map.end()) CHECK(it != graph_.node_map.end())
...@@ -174,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -174,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
} }
// Post order tree // Post order tree
void VisitExpr_(const FunctionNode* op) { void VisitExpr_(const FunctionNode* op) final {
for (auto param : op->params) { for (auto param : op->params) {
this->Update(param, nullptr, kOpaque); this->Update(param, nullptr, kOpaque);
} }
...@@ -182,7 +183,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -182,7 +183,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
} }
void VisitExpr_(const ConstantNode* op) { void VisitExpr_(const ConstantNode* op) final {
this->AddNode(op); this->AddNode(op);
Node* node = graph_.node_map.at(op); Node* node = graph_.node_map.at(op);
DataType dtype = TVMType2Type(op->data->dtype); DataType dtype = TVMType2Type(op->data->dtype);
...@@ -202,7 +203,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -202,7 +203,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
} }
} }
void VisitExpr_(const CallNode* call) { void VisitExpr_(const CallNode* call) final {
CHECK(graph_.node_map.count(call)); CHECK(graph_.node_map.count(call));
Node* node = graph_.node_map.at(call); Node* node = graph_.node_map.at(call);
static auto fpattern = static auto fpattern =
...@@ -232,7 +233,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -232,7 +233,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(call); this->AddNode(call);
} }
void VisitExpr_(const TupleNode* op) { void VisitExpr_(const TupleNode* op) final {
CHECK(graph_.node_map.count(op)); CHECK(graph_.node_map.count(op));
Node* tuple_node = graph_.node_map.at(op); Node* tuple_node = graph_.node_map.at(op);
tuple_node->pattern = kInjective; tuple_node->pattern = kInjective;
...@@ -247,7 +248,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -247,7 +248,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op); this->AddNode(op);
} }
void VisitExpr_(const TupleGetItemNode* op) { void VisitExpr_(const TupleGetItemNode* op) final {
CHECK(graph_.node_map.count(op)); CHECK(graph_.node_map.count(op));
Node* node = graph_.node_map.at(op); Node* node = graph_.node_map.at(op);
this->Update(op->tuple, node, kOpaque); this->Update(op->tuple, node, kOpaque);
...@@ -255,11 +256,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -255,11 +256,11 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op); this->AddNode(op);
} }
void VisitExpr_(const VarNode* op) { void VisitExpr_(const VarNode* op) final {
this->AddNode(op); this->AddNode(op);
} }
void VisitExpr_(const LetNode* op) { void VisitExpr_(const LetNode* op) final {
// do not fuse through let. // do not fuse through let.
this->Update(op->var, nullptr, kOpaque); this->Update(op->var, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque); this->Update(op->value, nullptr, kOpaque);
...@@ -268,7 +269,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -268,7 +269,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op); this->AddNode(op);
} }
void VisitExpr_(const IfNode* op) { void VisitExpr_(const IfNode* op) final {
// do not fuse through if. // do not fuse through if.
this->Update(op->cond, nullptr, kOpaque); this->Update(op->cond, nullptr, kOpaque);
this->Update(op->true_branch, nullptr, kOpaque); this->Update(op->true_branch, nullptr, kOpaque);
...@@ -276,6 +277,25 @@ class IndexedForwardGraph::Creator : private ExprVisitor { ...@@ -276,6 +277,25 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
ExprVisitor::VisitExpr_(op); ExprVisitor::VisitExpr_(op);
this->AddNode(op); this->AddNode(op);
} }
void VisitExpr_(const RefCreateNode* op) final {
this->Update(op->value, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const RefReadNode* op) final {
this->Update(op->ref, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
void VisitExpr_(const RefWriteNode* op) final {
this->Update(op->ref, nullptr, kOpaque);
this->Update(op->value, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
}
}; };
IndexedForwardGraph IndexedForwardGraph::Create( IndexedForwardGraph IndexedForwardGraph::Create(
......
...@@ -82,6 +82,12 @@ struct KindChecker : TypeVisitor { ...@@ -82,6 +82,12 @@ struct KindChecker : TypeVisitor {
valid = valid && IsTypeKind(op->ret_type); valid = valid && IsTypeKind(op->ret_type);
} }
void VisitType_(const RefTypeNode* op) override {
// tuples should only contain normal types
this->VisitType(op->value);
valid = valid && IsTypeKind(op->value);
}
void VisitType_(const TypeRelationNode* op) override { void VisitType_(const TypeRelationNode* op) override {
// arguments to type relation should be normal types // arguments to type relation should be normal types
for (const Type& t : op->args) { for (const Type& t : op->args) {
......
...@@ -431,6 +431,23 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -431,6 +431,23 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {}); auto ret = FuncTypeNode::make(arg_types, rtype, f->type_params, {});
return solver_.Resolve(ret); return solver_.Resolve(ret);
} }
Type VisitExpr_(const RefCreateNode* op) final {
return RefTypeNode::make(GetType(op->value));
}
Type VisitExpr_(const RefReadNode* op) final {
Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefRead>(op));
return it;
}
Type VisitExpr_(const RefWriteNode* op) final {
Type it = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
this->Unify(GetType(op->ref), RefTypeNode::make(it), GetRef<RefWrite>(op));
this->Unify(GetType(op->value), it, GetRef<RefWrite>(op));
return TupleTypeNode::make({});
}
}; };
class TypeInferencer::Resolver : public ExprMutator { class TypeInferencer::Resolver : public ExprMutator {
...@@ -480,6 +497,18 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -480,6 +497,18 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op); return AttachCheckedType(op);
} }
Expr VisitExpr_(const RefCreateNode* op) final {
return AttachCheckedType(op);
}
Expr VisitExpr_(const RefReadNode* op) final {
return AttachCheckedType(op);
}
Expr VisitExpr_(const RefWriteNode* op) final {
return AttachCheckedType(op);
}
// attach checked type to the mutated node. // attach checked type to the mutated node.
template<typename T> template<typename T>
Expr AttachCheckedType(const T* op) { Expr AttachCheckedType(const T* op) {
......
...@@ -116,7 +116,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -116,7 +116,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
} }
// default: unify only if alpha-equal // default: unify only if alpha-equal
Type VisitTypeDefault_(const Node* op, const Type& tn) override { Type VisitTypeDefault_(const Node* op, const Type& tn) final {
NodeRef nr = GetRef<NodeRef>(op); NodeRef nr = GetRef<NodeRef>(op);
Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>()); Type t1 = GetRef<Type>(nr.as_derived<tvm::relay::TypeNode>());
if (!AlphaEqual(t1, tn)) { if (!AlphaEqual(t1, tn)) {
...@@ -125,7 +125,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -125,7 +125,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return t1; return t1;
} }
Type VisitType_(const TupleTypeNode* op, const Type& tn) override { Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
const auto* ttn = tn.as<TupleTypeNode>(); const auto* ttn = tn.as<TupleTypeNode>();
if (!ttn || op->fields.size() != ttn->fields.size()) { if (!ttn || op->fields.size() != ttn->fields.size()) {
return Type(nullptr); return Type(nullptr);
...@@ -142,7 +142,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -142,7 +142,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return TupleTypeNode::make(new_fields); return TupleTypeNode::make(new_fields);
} }
Type VisitType_(const FuncTypeNode* op, const Type& tn) override { Type VisitType_(const FuncTypeNode* op, const Type& tn) final {
const auto* ftn = tn.as<FuncTypeNode>(); const auto* ftn = tn.as<FuncTypeNode>();
if (!ftn if (!ftn
|| op->arg_types.size() != ftn->arg_types.size() || op->arg_types.size() != ftn->arg_types.size()
...@@ -181,6 +181,14 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -181,6 +181,14 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints); return FuncTypeNode::make(arg_types, ret_type, ft1->type_params, type_constraints);
} }
Type VisitType_(const RefTypeNode* op, const Type& tn) final {
const auto* rtn = tn.as<RefTypeNode>();
if (!rtn) {
return Type(nullptr);
}
return RefTypeNode::make(Unify(op->value, rtn->value));
}
private: private:
TypeSolver* solver_; TypeSolver* solver_;
}; };
......
...@@ -110,6 +110,22 @@ def test_loop(): ...@@ -110,6 +110,22 @@ def test_loop():
check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod) check_eval(sum_up, [i_data, accum_data], sum(range(1, 11)), mod=mod)
def test_ref():
mod = relay.Module()
three_with_ref = relay.GlobalVar('three_with_ref')
i = relay.Var('i')
iv = relay.Var('iv')
u = relay.Var('u')
uv = relay.Var('uv')
body = relay.add(iv, uv)
body = relay.Let(uv, relay.RefRead(i), body)
body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
mod[three_with_ref] = relay.Function([], body)
check_eval(three_with_ref, [], 3, mod=mod)
def test_binds(): def test_binds():
x = relay.var("x") x = relay.var("x")
y = relay.add(x, x) y = relay.add(x, x)
...@@ -118,6 +134,7 @@ def test_binds(): ...@@ -118,6 +134,7 @@ def test_binds():
res = intrp.evaluate(y, binds={x: xx}).asnumpy() res = intrp.evaluate(y, binds={x: xx}).asnumpy()
tvm.testing.assert_allclose(xx + xx, res) tvm.testing.assert_allclose(xx + xx, res)
def test_kwargs_params(): def test_kwargs_params():
x = relay.var("x", shape=(1, 10)) x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10)) y = relay.var("y", shape=(1, 10))
...@@ -131,6 +148,7 @@ def test_kwargs_params(): ...@@ -131,6 +148,7 @@ def test_kwargs_params():
res = intrp.evaluate(f)(x_data, **params).data res = intrp.evaluate(f)(x_data, **params).data
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
test_add_const() test_add_const()
...@@ -140,3 +158,4 @@ if __name__ == "__main__": ...@@ -140,3 +158,4 @@ if __name__ == "__main__":
test_loop() test_loop()
test_binds() test_binds()
test_kwargs_params() test_kwargs_params()
test_ref()
...@@ -131,6 +131,18 @@ def test_tuple(): ...@@ -131,6 +131,18 @@ def test_tuple():
relay.TupleType([tp, tp])) relay.TupleType([tp, tp]))
def test_ref():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
r = relay.RefCreate(x)
st = relay.scalar_type("float32")
assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st)
g = relay.RefRead(r)
assert relay.ir_pass.infer_type(g).checked_type == st
w = relay.RefWrite(r, y)
assert relay.ir_pass.infer_type(w).checked_type == relay.TupleType([])
def test_free_expr(): def test_free_expr():
x = relay.var("x", "float32") x = relay.var("x", "float32")
y = relay.add(x, x) y = relay.add(x, x)
...@@ -187,12 +199,9 @@ if __name__ == "__main__": ...@@ -187,12 +199,9 @@ if __name__ == "__main__":
test_decl() test_decl()
test_recursion() test_recursion()
test_tuple() test_tuple()
test_generalized_tuple()
test_incomplete_call() test_incomplete_call()
test_generalized_call()
test_call_with_type_args()
test_free_expr() test_free_expr()
test_type_args() test_type_args()
test_self_reference()
test_global_var_recursion() test_global_var_recursion()
test_equal() test_equal()
test_ref()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment