Unverified Commit 7af48f1a by Tianqi Chen Committed by GitHub

[RELAY][IR] Introduce IdNode to preserve var id across rewriting (#2178)

parent 246a38a1
...@@ -165,6 +165,34 @@ class RelayNode : public Node { ...@@ -165,6 +165,34 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node); TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
}; };
/*!
* \brief The unique identifier of variables.
*
* Id is like name to the variables,
* except that id is unique for each Var.
*
* \note Do not create Id directly, they are created in Var.
*/
class IdNode : public Node {
public:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
}
static constexpr const char* _type_key = "relay.Id";
TVM_DECLARE_NODE_TYPE_INFO(IdNode, Node);
};
RELAY_DEFINE_NODE_REF(Id, IdNode, NodeRef);
struct Module; struct Module;
} // namespace relay } // namespace relay
......
...@@ -124,18 +124,22 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr); ...@@ -124,18 +124,22 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
* Its semantics are similar to tvm.Var node used in TVM's low level * Its semantics are similar to tvm.Var node used in TVM's low level
* tensor expression language. * tensor expression language.
* *
* \note Each Var is bind only once and is immutable/ * \note Each Var is bind only once and is immutable.
*/ */
class Var; class Var;
/*! \brief Container for Var */ /*! \brief Container for Var */
class VarNode : public ExprNode { class VarNode : public ExprNode {
public: public:
/*! /*!
* \brief The name of the variable, * \brief The unique identifier of the Var.
* this only acts as a hint to the user, *
* and is not used for equality. * vid will be preserved for the same Var during type inference
* and other rewritings, while the VarNode might be recreated
* to attach additional information.
* This property can be used to keep track of parameter Var
* information across passes.
*/ */
std::string name_hint; Id vid;
/*! /*!
* \brief type annotaion of the variable. * \brief type annotaion of the variable.
* This field records user provided type annotation of the Var. * This field records user provided type annotation of the Var.
...@@ -143,8 +147,13 @@ class VarNode : public ExprNode { ...@@ -143,8 +147,13 @@ class VarNode : public ExprNode {
*/ */
Type type_annotation; Type type_annotation;
/*! \return The name hint of the variable */
const std::string& name_hint() const {
return vid->name_hint;
}
void VisitAttrs(tvm::AttrVisitor* v) final { void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint); v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation); v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span); v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
...@@ -153,6 +162,9 @@ class VarNode : public ExprNode { ...@@ -153,6 +162,9 @@ class VarNode : public ExprNode {
TVM_DLL static Var make(std::string name_hint, TVM_DLL static Var make(std::string name_hint,
Type type_annotation); Type type_annotation);
TVM_DLL static Var make(Id vid,
Type type_annotation);
static constexpr const char* _type_key = "relay.Var"; static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode); TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
}; };
......
...@@ -54,3 +54,10 @@ class RelayNode(NodeBase): ...@@ -54,3 +54,10 @@ class RelayNode(NodeBase):
class Span(RelayNode): class Span(RelayNode):
def __init__(self, source, lineno, col_offset): def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
@register_relay_node
class Id(NodeBase):
"""Unique identifier(name) for Var across type checking."""
def __init__(self):
raise RuntimeError("Cannot directly construct Id")
...@@ -166,6 +166,12 @@ class Var(Expr): ...@@ -166,6 +166,12 @@ class Var(Expr):
self.__init_handle_by_constructor__( self.__init_handle_by_constructor__(
_make.Var, name_hint, type_annotation) _make.Var, name_hint, type_annotation)
@property
def name_hint(self):
"""Get name hint of the current var."""
name = self.vid.name_hint
return name
@register_relay_node @register_relay_node
class GlobalVar(Expr): class GlobalVar(Expr):
......
...@@ -99,7 +99,7 @@ class ScheduleGetter : ...@@ -99,7 +99,7 @@ class ScheduleGetter :
} }
Array<Tensor> VisitExpr_(const VarNode* op) final { Array<Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint; LOG(FATAL) << "Free variable " << op->name_hint();
return {}; return {};
} }
......
...@@ -240,8 +240,9 @@ class AlphaEqualHandler: ...@@ -240,8 +240,9 @@ class AlphaEqualHandler:
} }
bool VisitExpr_(const VarNode* lhs, const Expr& other) final { bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
// This function will only be triggered if we are matching free variables.
if (const VarNode* rhs = other.as<VarNode>()) { if (const VarNode* rhs = other.as<VarNode>()) {
if (lhs->name_hint != rhs->name_hint) return false; if (lhs->name_hint() != rhs->name_hint()) return false;
if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false; if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other); return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} else { } else {
......
...@@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->col_offset << ")"; << node->col_offset << ")";
}); });
TVM_REGISTER_NODE_TYPE(IdNode);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -63,23 +63,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -63,23 +63,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "Tuple(" << node->fields << ")"; p->stream << "Tuple(" << node->fields << ")";
}); });
Var VarNode::make(std::string name_hint, Type type_annotation) {
Var VarNode::make(Id vid, Type type_annotation) {
NodePtr<VarNode> n = make_node<VarNode>(); NodePtr<VarNode> n = make_node<VarNode>();
n->name_hint = std::move(name_hint); n->vid = std::move(vid);
n->type_annotation = std::move(type_annotation); n->type_annotation = std::move(type_annotation);
return Var(n); return Var(n);
} }
Var VarNode::make(std::string name_hint, Type type_annotation) {
NodePtr<IdNode> n = make_node<IdNode>();
n->name_hint = std::move(name_hint);
return VarNode::make(Id(n), type_annotation);
}
TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var") TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VarNode::make(args[0], args[1]); *ret = VarNode::make(args[0].operator std::string(), args[1]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) { .set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
p->stream << "Var(" << node->name_hint; p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) { if (node->type_annotation.defined()) {
p->stream << ", ty="; p->stream << ", ty=";
p->print(node->type_annotation); p->print(node->type_annotation);
......
...@@ -30,7 +30,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) { ...@@ -30,7 +30,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
if (op->type_annotation.defined()) { if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation); auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) { if (!op->type_annotation.same_as(type)) {
return VarNode::make(op->name_hint, type); return VarNode::make(op->vid, type);
} }
} }
// default case return self. // default case return self.
......
...@@ -202,7 +202,8 @@ class RelayHashHandler: ...@@ -202,7 +202,8 @@ class RelayHashHandler:
} }
size_t VisitExpr_(const VarNode* var) final { size_t VisitExpr_(const VarNode* var) final {
size_t name_hash = std::hash<std::string>()(var->name_hint); // hash free variable
size_t name_hash = std::hash<const Node*>()(var->vid.get());
return Combine(name_hash, TypeHash(var->type_annotation)); return Combine(name_hash, TypeHash(var->type_annotation));
} }
......
...@@ -690,7 +690,7 @@ class TextPrinter : ...@@ -690,7 +690,7 @@ class TextPrinter :
* \return The corresponding name. * \return The corresponding name.
*/ */
TextValue AllocVarName(const Var& var) { TextValue AllocVarName(const Var& var) {
std::string name = var->name_hint; std::string name = var->name_hint();
// always make sure first name is alpha // always make sure first name is alpha
if (name.length() != 0 && !std::isalpha(name[0])) { if (name.length() != 0 && !std::isalpha(name[0])) {
name = "%v" + name; name = "%v" + name;
......
...@@ -141,6 +141,7 @@ def test_free_expr(): ...@@ -141,6 +141,7 @@ def test_free_expr():
y = relay.add(x, x) y = relay.add(x, x)
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32") assert yy.checked_type == relay.scalar_type("float32")
assert x.vid.same_as(yy.args[0].vid)
def test_type_args(): def test_type_args():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment