Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7af48f1a
Unverified
Commit
7af48f1a
authored
Nov 26, 2018
by
Tianqi Chen
Committed by
GitHub
Nov 26, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][IR] Introduce IdNode to preserve var id across rewriting (#2178)
parent
246a38a1
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
80 additions
and
17 deletions
+80
-17
include/tvm/relay/base.h
+28
-0
include/tvm/relay/expr.h
+18
-6
python/tvm/relay/base.py
+7
-0
python/tvm/relay/expr.py
+6
-0
src/relay/backend/compile_engine.cc
+1
-1
src/relay/ir/alpha_equal.cc
+2
-1
src/relay/ir/base.cc
+2
-2
src/relay/ir/expr.cc
+11
-4
src/relay/ir/expr_functor.cc
+1
-1
src/relay/ir/hash.cc
+2
-1
src/relay/ir/text_printer.cc
+1
-1
tests/python/relay/test_type_infer.py
+1
-0
No files found.
include/tvm/relay/base.h
View file @
7af48f1a
...
...
@@ -165,6 +165,34 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO
(
RelayNode
,
Node
);
};
/*!
* \brief The unique identifier of variables.
*
* Id is like name to the variables,
* except that id is unique for each Var.
*
* \note Do not create Id directly, they are created in Var.
*/
class
IdNode
:
public
Node
{
public
:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
*/
std
::
string
name_hint
;
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"name_hint"
,
&
name_hint
);
}
static
constexpr
const
char
*
_type_key
=
"relay.Id"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IdNode
,
Node
);
};
RELAY_DEFINE_NODE_REF
(
Id
,
IdNode
,
NodeRef
);
struct
Module
;
}
// namespace relay
...
...
include/tvm/relay/expr.h
View file @
7af48f1a
...
...
@@ -124,18 +124,22 @@ RELAY_DEFINE_NODE_REF(Tuple, TupleNode, Expr);
* Its semantics are similar to tvm.Var node used in TVM's low level
* tensor expression language.
*
* \note Each Var is bind only once and is immutable
/
* \note Each Var is bind only once and is immutable
.
*/
class
Var
;
/*! \brief Container for Var */
class
VarNode
:
public
ExprNode
{
public
:
/*!
* \brief The name of the variable,
* this only acts as a hint to the user,
* and is not used for equality.
* \brief The unique identifier of the Var.
*
* vid will be preserved for the same Var during type inference
* and other rewritings, while the VarNode might be recreated
* to attach additional information.
* This property can be used to keep track of parameter Var
* information across passes.
*/
std
::
string
name_hint
;
Id
vid
;
/*!
* \brief type annotaion of the variable.
* This field records user provided type annotation of the Var.
...
...
@@ -143,8 +147,13 @@ class VarNode : public ExprNode {
*/
Type
type_annotation
;
/*! \return The name hint of the variable */
const
std
::
string
&
name_hint
()
const
{
return
vid
->
name_hint
;
}
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"
name_hint"
,
&
name_hint
);
v
->
Visit
(
"
vid"
,
&
vid
);
v
->
Visit
(
"type_annotation"
,
&
type_annotation
);
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
...
...
@@ -153,6 +162,9 @@ class VarNode : public ExprNode {
TVM_DLL
static
Var
make
(
std
::
string
name_hint
,
Type
type_annotation
);
TVM_DLL
static
Var
make
(
Id
vid
,
Type
type_annotation
);
static
constexpr
const
char
*
_type_key
=
"relay.Var"
;
TVM_DECLARE_NODE_TYPE_INFO
(
VarNode
,
ExprNode
);
};
...
...
python/tvm/relay/base.py
View file @
7af48f1a
...
...
@@ -54,3 +54,10 @@ class RelayNode(NodeBase):
class
Span
(
RelayNode
):
def
__init__
(
self
,
source
,
lineno
,
col_offset
):
self
.
__init_handle_by_constructor__
(
_make
.
Span
,
source
,
lineno
,
col_offset
)
@register_relay_node
class
Id
(
NodeBase
):
"""Unique identifier(name) for Var across type checking."""
def
__init__
(
self
):
raise
RuntimeError
(
"Cannot directly construct Id"
)
python/tvm/relay/expr.py
View file @
7af48f1a
...
...
@@ -166,6 +166,12 @@ class Var(Expr):
self
.
__init_handle_by_constructor__
(
_make
.
Var
,
name_hint
,
type_annotation
)
@property
def
name_hint
(
self
):
"""Get name hint of the current var."""
name
=
self
.
vid
.
name_hint
return
name
@register_relay_node
class
GlobalVar
(
Expr
):
...
...
src/relay/backend/compile_engine.cc
View file @
7af48f1a
...
...
@@ -99,7 +99,7 @@ class ScheduleGetter :
}
Array
<
Tensor
>
VisitExpr_
(
const
VarNode
*
op
)
final
{
LOG
(
FATAL
)
<<
"Free variable "
<<
op
->
name_hint
;
LOG
(
FATAL
)
<<
"Free variable "
<<
op
->
name_hint
()
;
return
{};
}
...
...
src/relay/ir/alpha_equal.cc
View file @
7af48f1a
...
...
@@ -240,8 +240,9 @@ class AlphaEqualHandler:
}
bool
VisitExpr_
(
const
VarNode
*
lhs
,
const
Expr
&
other
)
final
{
// This function will only be triggered if we are matching free variables.
if
(
const
VarNode
*
rhs
=
other
.
as
<
VarNode
>
())
{
if
(
lhs
->
name_hint
!=
rhs
->
name_hint
)
return
false
;
if
(
lhs
->
name_hint
()
!=
rhs
->
name_hint
()
)
return
false
;
if
(
!
TypeEqual
(
lhs
->
type_annotation
,
rhs
->
type_annotation
))
return
false
;
return
LeafNodeEqual
(
GetRef
<
NodeRef
>
(
lhs
),
other
);
}
else
{
...
...
src/relay/ir/base.cc
View file @
7af48f1a
...
...
@@ -64,7 +64,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<<
node
->
col_offset
<<
")"
;
});
TVM_REGISTER_NODE_TYPE
(
IdNode
);
}
// namespace relay
}
// namespace tvm
src/relay/ir/expr.cc
View file @
7af48f1a
...
...
@@ -63,23 +63,30 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p
->
stream
<<
"Tuple("
<<
node
->
fields
<<
")"
;
});
Var
VarNode
::
make
(
std
::
string
name_hint
,
Type
type_annotation
)
{
Var
VarNode
::
make
(
Id
vid
,
Type
type_annotation
)
{
NodePtr
<
VarNode
>
n
=
make_node
<
VarNode
>
();
n
->
name_hint
=
std
::
move
(
name_hint
);
n
->
vid
=
std
::
move
(
vid
);
n
->
type_annotation
=
std
::
move
(
type_annotation
);
return
Var
(
n
);
}
Var
VarNode
::
make
(
std
::
string
name_hint
,
Type
type_annotation
)
{
NodePtr
<
IdNode
>
n
=
make_node
<
IdNode
>
();
n
->
name_hint
=
std
::
move
(
name_hint
);
return
VarNode
::
make
(
Id
(
n
),
type_annotation
);
}
TVM_REGISTER_NODE_TYPE
(
VarNode
);
TVM_REGISTER_API
(
"relay._make.Var"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
VarNode
::
make
(
args
[
0
],
args
[
1
]);
*
ret
=
VarNode
::
make
(
args
[
0
]
.
operator
std
::
string
()
,
args
[
1
]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
VarNode
>
([](
const
VarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Var("
<<
node
->
name_hint
;
p
->
stream
<<
"Var("
<<
node
->
name_hint
()
;
if
(
node
->
type_annotation
.
defined
())
{
p
->
stream
<<
", ty="
;
p
->
print
(
node
->
type_annotation
);
...
...
src/relay/ir/expr_functor.cc
View file @
7af48f1a
...
...
@@ -30,7 +30,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
if
(
op
->
type_annotation
.
defined
())
{
auto
type
=
this
->
VisitType
(
op
->
type_annotation
);
if
(
!
op
->
type_annotation
.
same_as
(
type
))
{
return
VarNode
::
make
(
op
->
name_hint
,
type
);
return
VarNode
::
make
(
op
->
vid
,
type
);
}
}
// default case return self.
...
...
src/relay/ir/hash.cc
View file @
7af48f1a
...
...
@@ -202,7 +202,8 @@ class RelayHashHandler:
}
size_t
VisitExpr_
(
const
VarNode
*
var
)
final
{
size_t
name_hash
=
std
::
hash
<
std
::
string
>
()(
var
->
name_hint
);
// hash free variable
size_t
name_hash
=
std
::
hash
<
const
Node
*>
()(
var
->
vid
.
get
());
return
Combine
(
name_hash
,
TypeHash
(
var
->
type_annotation
));
}
...
...
src/relay/ir/text_printer.cc
View file @
7af48f1a
...
...
@@ -690,7 +690,7 @@ class TextPrinter :
* \return The corresponding name.
*/
TextValue
AllocVarName
(
const
Var
&
var
)
{
std
::
string
name
=
var
->
name_hint
;
std
::
string
name
=
var
->
name_hint
()
;
// always make sure first name is alpha
if
(
name
.
length
()
!=
0
&&
!
std
::
isalpha
(
name
[
0
]))
{
name
=
"%v"
+
name
;
...
...
tests/python/relay/test_type_infer.py
View file @
7af48f1a
...
...
@@ -141,6 +141,7 @@ def test_free_expr():
y
=
relay
.
add
(
x
,
x
)
yy
=
relay
.
ir_pass
.
infer_type
(
y
)
assert
yy
.
checked_type
==
relay
.
scalar_type
(
"float32"
)
assert
x
.
vid
.
same_as
(
yy
.
args
[
0
]
.
vid
)
def
test_type_args
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment