Unverified Commit 7af48f1a by Tianqi Chen Committed by GitHub

[RELAY][IR] Introduce IdNode to preserve var id across rewriting (#2178)

parent 246a38a1
......@@ -165,6 +165,34 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
};
/*!
* \brief The unique identifier of variables.
*
* Id is like name to the variables,
* except that id is unique for each Var.
*
* \note Do not create Id directly, they are created in Var.
*/
class IdNode : public Node {
public:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
}
static constexpr const char* _type_key = "relay.Id";
TVM_DECLARE_NODE_TYPE_INFO(IdNode, Node);
};
RELAY_DEFINE_NODE_REF(Id, IdNode, NodeRef);
struct Module;
} // namespace relay
......
......@@ -124,18 +124,22 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
* Its semantics are similar to tvm.Var node used in TVM's low level
* tensor expression language.
*
* \note Each Var is bind only once and is immutable/
* \note Each Var is bind only once and is immutable.
*/
class Var;
/*! \brief Container for Var */
class VarNode : public ExprNode {
public:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
* \brief The unique identifier of the Var.
*
* vid will be preserved for the same Var during type inference
* and other rewritings, while the VarNode might be recreated
* to attach additional information.
* This property can be used to keep track of parameter Var
* information across passes.
*/
std::string name_hint;
Id vid;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
......@@ -143,8 +147,13 @@ class VarNode : public ExprNode {
*/
Type type_annotation;
/*! \return The name hint of the variable */
const std::string& name_hint() const {
return vid->name_hint;
}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name_hint", &name_hint);
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
......@@ -153,6 +162,9 @@ class VarNode : public ExprNode {
TVM_DLL static Var make(std::string name_hint,
Type type_annotation);
TVM_DLL static Var make(Id vid,
Type type_annotation);
static constexpr const char* _type_key = "relay.Var";
TVM_DECLARE_NODE_TYPE_INFO(VarNode, ExprNode);
};
......
......@@ -54,3 +54,10 @@ class RelayNode(NodeBase):
class Span(RelayNode):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
@register_relay_node
class Id(NodeBase):
"""Unique identifier(name) for Var across type checking."""
def __init__(self):
raise RuntimeError("Cannot directly construct Id")
......@@ -166,6 +166,12 @@ class Var(Expr):
self.__init_handle_by_constructor__(
_make.Var, name_hint, type_annotation)
@property
def name_hint(self):
"""Get name hint of the current var."""
name = self.vid.name_hint
return name
@register_relay_node
class GlobalVar(Expr):
......
......@@ -99,7 +99,7 @@ class ScheduleGetter :
}
Array<Tensor> VisitExpr_(const VarNode* op) final {
LOG(FATAL) << "Free variable " << op->name_hint;
LOG(FATAL) << "Free variable " << op->name_hint();
return {};
}
......
......@@ -240,8 +240,9 @@ class AlphaEqualHandler:
}
bool VisitExpr_(const VarNode* lhs, const Expr& other) final {
// This function will only be triggered if we are matching free variables.
if (const VarNode* rhs = other.as<VarNode>()) {
if (lhs->name_hint != rhs->name_hint) return false;
if (lhs->name_hint() != rhs->name_hint()) return false;
if (!TypeEqual(lhs->type_annotation, rhs->type_annotation)) return false;
return LeafNodeEqual(GetRef<NodeRef>(lhs), other);
} else {
......
......@@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< node->col_offset << ")";
});
TVM_REGISTER_NODE_TYPE(IdNode);
} // namespace relay
} // namespace tvm
......@@ -63,23 +63,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "Tuple(" << node->fields << ")";
});
Var VarNode::make(std::string name_hint, Type type_annotation) {
Var VarNode::make(Id vid, Type type_annotation) {
NodePtr<VarNode> n = make_node<VarNode>();
n->name_hint = std::move(name_hint);
n->vid = std::move(vid);
n->type_annotation = std::move(type_annotation);
return Var(n);
}
Var VarNode::make(std::string name_hint, Type type_annotation) {
NodePtr<IdNode> n = make_node<IdNode>();
n->name_hint = std::move(name_hint);
return VarNode::make(Id(n), type_annotation);
}
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VarNode::make(args[0], args[1]);
*ret = VarNode::make(args[0].operator std::string(), args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
p->stream << "Var(" << node->name_hint;
p->stream << "Var(" << node->name_hint();
if (node->type_annotation.defined()) {
p->stream << ", ty=";
p->print(node->type_annotation);
......
......@@ -30,7 +30,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
return VarNode::make(op->name_hint, type);
return VarNode::make(op->vid, type);
}
}
// default case return self.
......
......@@ -202,7 +202,8 @@ class RelayHashHandler:
}
size_t VisitExpr_(const VarNode* var) final {
size_t name_hash = std::hash<std::string>()(var->name_hint);
// hash free variable
size_t name_hash = std::hash<const Node*>()(var->vid.get());
return Combine(name_hash, TypeHash(var->type_annotation));
}
......
......@@ -690,7 +690,7 @@ class TextPrinter :
* \return The corresponding name.
*/
TextValue AllocVarName(const Var& var) {
std::string name = var->name_hint;
std::string name = var->name_hint();
// always make sure first name is alpha
if (name.length() != 0 && !std::isalpha(name[0])) {
name = "%v" + name;
......
......@@ -141,6 +141,7 @@ def test_free_expr():
y = relay.add(x, x)
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32")
assert x.vid.same_as(yy.args[0].vid)
def test_type_args():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment