Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ec0d497c
Unverified
Commit
ec0d497c
authored
Sep 20, 2018
by
Tianqi Chen
Committed by
GitHub
Sep 20, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NODE][RELAY] Move most of the reference related code to node (#1747)
parent
1c2b0b65
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
120 additions
and
105 deletions
+120
-105
include/tvm/node/node.h
+47
-5
include/tvm/relay/base.h
+0
-37
include/tvm/relay/expr.h
+4
-2
include/tvm/relay/expr_functor.h
+12
-12
src/relay/ir/environment.cc
+2
-2
src/relay/ir/expr_functor.cc
+52
-45
src/relay/pass/type_visitor.h
+2
-1
tests/cpp/expr_test.cc
+1
-1
No files found.
include/tvm/node/node.h
View file @
ec0d497c
...
...
@@ -102,10 +102,10 @@ class TVM_DLL Node : public NodeBase {
template
<
typename
T
>
inline
bool
is_type
()
const
;
/*!
* \brief Get a Node
Ref
that holds reference to this Node.
* \return the Node
Ref
* \brief Get a Node
Ptr
that holds reference to this Node.
* \return the Node
Ptr
*/
inline
Node
Ref
GetNodeRef
()
const
;
inline
Node
Ptr
<
Node
>
GetNodePtr
()
const
;
// node ref can see this
friend
class
NodeRef
;
static
constexpr
const
char
*
_type_key
=
"Node"
;
...
...
@@ -177,6 +177,32 @@ class NodeRef {
};
/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template
<
typename
RefType
,
typename
NodeType
>
inline
RefType
GetRef
(
const
NodeType
*
ptr
);
/*!
* \brief Downcast a base reference type to a more specific type.
*
* \param ref The inptut reference
* \return The corresponding SubRef.
* \tparam SubRef The target specific reference type.
* \tparam BaseRef the current reference type.
*/
template
<
typename
SubRef
,
typename
BaseRef
>
inline
SubRef
Downcast
(
BaseRef
ref
);
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
...
...
@@ -218,8 +244,24 @@ inline bool Node::derived_from() const {
return
this
->
_DerivedFrom
(
type_id
);
}
inline
NodeRef
Node
::
GetNodeRef
()
const
{
return
NodeRef
(
NodePtr
<
Node
>
(
const_cast
<
Node
*>
(
this
)));
inline
NodePtr
<
Node
>
Node
::
GetNodePtr
()
const
{
return
NodePtr
<
Node
>
(
const_cast
<
Node
*>
(
this
));
}
template
<
typename
RefType
,
typename
NodeType
>
inline
RefType
GetRef
(
const
NodeType
*
ptr
)
{
static_assert
(
std
::
is_base_of
<
typename
RefType
::
ContainerType
,
NodeType
>::
value
,
"Can only cast to the ref of same container type"
);
return
RefType
(
ptr
->
GetNodePtr
());
}
template
<
typename
SubRef
,
typename
BaseRef
>
inline
SubRef
Downcast
(
BaseRef
ref
)
{
CHECK
(
ref
->
template
is_type
<
typename
SubRef
::
ContainerType
>
()
||
ref
->
template
derived_from
<
typename
SubRef
::
ContainerType
>
())
<<
"Downcast from "
<<
ref
->
type_key
()
<<
" to "
<<
SubRef
::
ContainerType
::
_type_key
<<
" failed."
;
return
SubRef
(
std
::
move
(
ref
.
node_
));
}
inline
const
Node
*
NodeRef
::
get
()
const
{
...
...
include/tvm/relay/base.h
View file @
ec0d497c
...
...
@@ -158,43 +158,6 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO
(
RelayNode
,
Node
);
};
/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template
<
typename
RefType
,
typename
NodeType
>
RefType
GetRef
(
const
NodeType
*
ptr
)
{
static_assert
(
std
::
is_same
<
typename
RefType
::
ContainerType
,
NodeType
>::
value
,
"Can only cast to the ref of same container type"
);
return
RefType
(
std
::
move
(
ptr
->
GetNodeRef
().
node_
));
}
// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
template
<
typename
T
>
inline
const
T
*
As
(
const
NodeRef
&
node
)
{
const
Node
*
ptr
=
static_cast
<
const
Node
*>
(
node
.
get
());
if
(
ptr
&&
(
ptr
->
is_type
<
T
>
()
||
ptr
->
derived_from
<
T
>
()))
{
return
static_cast
<
const
T
*>
(
ptr
);
}
return
nullptr
;
}
template
<
typename
SubRef
,
typename
BaseRef
>
SubRef
Downcast
(
BaseRef
ref
)
{
CHECK
(
ref
->
template
is_type
<
typename
SubRef
::
ContainerType
>
())
<<
"Downcast from "
<<
ref
->
type_key
()
<<
" to "
<<
SubRef
::
ContainerType
::
_type_key
<<
" failed."
;
return
SubRef
(
ref
.
node_
);
}
}
// namespace relay
}
// namespace tvm
...
...
include/tvm/relay/expr.h
View file @
ec0d497c
...
...
@@ -65,7 +65,9 @@ class ConstantNode : public ExprNode {
TensorType
tensor_type
()
const
;
/*! \return Whether it is scalar(rank-0 tensor) */
bool
is_scalar
()
const
{
return
data
->
ndim
==
0
;
}
bool
is_scalar
()
const
{
return
data
->
ndim
==
0
;
}
void
VisitAttrs
(
tvm
::
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"data"
,
&
data
);
...
...
@@ -341,7 +343,7 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);
*
* let x = if (true) { 1 } else { 0 }; // x is 1
* let y = if (false) { 1 } else { 0 }; // y is 0
*
*
* \note This is similar to C's ternary operator.
*/
class
If
;
...
...
include/tvm/relay/expr_functor.h
View file @
ec0d497c
...
...
@@ -139,19 +139,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
* the cost of using functional updates.
*/
class
ExprMutator
:
public
::
tvm
::
relay
::
ExprFunctor
<
Expr
(
const
Expr
&
,
const
Expr
&
)
>
{
:
public
::
tvm
::
relay
::
ExprFunctor
<
Expr
(
const
Expr
&
)
>
{
public
:
Expr
Mutate
(
const
Expr
&
expr
);
Expr
VisitExpr_
(
const
VarNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
ConstantNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
GlobalVarNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
OpNode
*
op
,
const
Expr
&
expr
)
override
;
Expr
VisitExpr_
(
const
TupleNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
ParamNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
FunctionNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
CallNode
*
call_node
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
IfNode
*
op
,
const
Expr
&
e
)
override
;
Expr
VisitExpr_
(
const
VarNode
*
op
)
override
;
Expr
VisitExpr_
(
const
ConstantNode
*
op
)
override
;
Expr
VisitExpr_
(
const
GlobalVarNode
*
op
)
override
;
Expr
VisitExpr_
(
const
OpNode
*
op
)
override
;
Expr
VisitExpr_
(
const
TupleNode
*
op
)
override
;
Expr
VisitExpr_
(
const
ParamNode
*
op
)
override
;
Expr
VisitExpr_
(
const
FunctionNode
*
op
)
override
;
Expr
VisitExpr_
(
const
CallNode
*
call_node
)
override
;
Expr
VisitExpr_
(
const
LetNode
*
op
)
override
;
Expr
VisitExpr_
(
const
IfNode
*
op
)
override
;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
...
...
@@ -162,7 +162,7 @@ class ExprMutator
private
:
/*! \brief Internal map used for memoization. */
tvm
::
Map
<
Expr
,
Expr
>
memo_
;
std
::
unordered_map
<
Expr
,
Expr
,
NodeHash
,
NodeEqual
>
memo_
;
};
}
// namespace relay
...
...
src/relay/ir/environment.cc
View file @
ec0d497c
...
...
@@ -41,12 +41,12 @@ void EnvironmentNode::Add(const GlobalVar &var,
const
Function
&
func
,
bool
update
)
{
// Type check the item before we add it to the environment.
auto
env
=
relay
::
GetRef
<
Environment
>
(
this
);
auto
env
=
GetRef
<
Environment
>
(
this
);
Expr
checked_expr
=
InferType
(
env
,
var
,
func
);
if
(
const
FunctionNode
*
func_node
=
checked_expr
.
as
<
FunctionNode
>
())
{
auto
checked_func
=
relay
::
GetRef
<
Function
>
(
func_node
);
auto
checked_func
=
GetRef
<
Function
>
(
func_node
);
auto
type
=
checked_func
->
checked_type
();
CHECK
(
IsFullyResolved
(
type
));
...
...
src/relay/ir/expr_functor.cc
View file @
ec0d497c
...
...
@@ -13,33 +13,33 @@ namespace tvm {
namespace
relay
{
Expr
ExprMutator
::
Mutate
(
const
Expr
&
expr
)
{
auto
cached_expr
=
this
->
memo_
.
find
(
expr
);
if
(
cached_expr
!=
this
->
memo_
.
end
())
{
return
(
*
cached_expr
).
second
;
auto
it
=
this
->
memo_
.
find
(
expr
);
if
(
it
!=
this
->
memo_
.
end
())
{
return
it
->
second
;
}
else
{
auto
new_expr
=
this
->
ExprMutator
::
VisitExpr
(
expr
,
expr
);
this
->
memo_
.
Set
(
expr
,
new_expr
)
;
Expr
new_expr
=
ExprMutator
::
VisitExpr
(
expr
);
memo_
[
expr
]
=
new_expr
;
return
new_expr
;
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
VarNode
*
op
,
const
Expr
&
expr
)
{
return
expr
;
Expr
ExprMutator
::
VisitExpr_
(
const
VarNode
*
op
)
{
return
GetRef
<
Expr
>
(
op
)
;
}
Expr
ExprMutator
::
VisitExpr_
(
const
ConstantNode
*
op
,
const
Expr
&
expr
)
{
return
expr
;
Expr
ExprMutator
::
VisitExpr_
(
const
ConstantNode
*
op
)
{
return
GetRef
<
Expr
>
(
op
)
;
}
Expr
ExprMutator
::
VisitExpr_
(
const
GlobalVarNode
*
op
,
const
Expr
&
expr
)
{
return
expr
;
Expr
ExprMutator
::
VisitExpr_
(
const
GlobalVarNode
*
op
)
{
return
GetRef
<
Expr
>
(
op
)
;
}
Expr
ExprMutator
::
VisitExpr_
(
const
OpNode
*
op
,
const
Expr
&
expr
)
{
return
expr
;
Expr
ExprMutator
::
VisitExpr_
(
const
OpNode
*
op
)
{
return
GetRef
<
Expr
>
(
op
)
;
}
Expr
ExprMutator
::
VisitExpr_
(
const
TupleNode
*
op
,
const
Expr
&
e
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
TupleNode
*
op
)
{
tvm
::
Array
<
Expr
>
fields
;
bool
all_fields_unchanged
=
true
;
for
(
auto
field
:
op
->
fields
)
{
...
...
@@ -49,23 +49,23 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
}
if
(
all_fields_unchanged
)
{
return
e
;
return
GetRef
<
Expr
>
(
op
)
;
}
else
{
return
TupleNode
::
make
(
fields
);
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
ParamNode
*
op
,
const
Expr
&
e
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
ParamNode
*
op
)
{
Var
var
=
Downcast
<
Var
>
(
this
->
Mutate
(
op
->
var
));
auto
type
=
this
->
VisitType
(
op
->
type
);
if
(
var
==
op
->
var
&&
type
==
op
->
type
)
{
return
e
;
if
(
op
->
var
.
same_as
(
var
)
&&
op
->
type
.
same_as
(
type
)
)
{
return
GetRef
<
Expr
>
(
op
)
;
}
else
{
return
ParamNode
::
make
(
var
,
type
);
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
FunctionNode
*
op
,
const
Expr
&
e
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
FunctionNode
*
op
)
{
tvm
::
Array
<
TypeParam
>
ty_params
;
bool
all_ty_params_changed
=
true
;
...
...
@@ -86,74 +86,82 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
auto
ret_type
=
this
->
VisitType
(
op
->
ret_type
);
auto
body
=
this
->
Mutate
(
op
->
body
);
if
(
ty_params
.
same_as
(
op
->
type_params
)
&&
params
.
same_as
(
op
->
params
)
&&
ret_type
.
same_as
(
op
->
ret_type
)
&&
body
.
same_as
(
op
->
body
))
{
return
e
;
if
(
ty_params
.
same_as
(
op
->
type_params
)
&&
params
.
same_as
(
op
->
params
)
&&
ret_type
.
same_as
(
op
->
ret_type
)
&&
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
Expr
>
(
op
);
}
else
{
return
FunctionNode
::
make
(
params
,
ret_type
,
body
,
ty_params
);
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
CallNode
*
call_node
,
const
Expr
&
e
)
{
auto
op
=
this
->
Mutate
(
call_node
->
op
);
Expr
ExprMutator
::
VisitExpr_
(
const
CallNode
*
call_node
)
{
auto
new_op
=
this
->
Mutate
(
call_node
->
op
);
bool
unchanged
=
call_node
->
op
.
same_as
(
new_op
);
tvm
::
Array
<
Type
>
ty_args
;
bool
all_ty_args_unchanged
=
true
;
for
(
auto
ty_arg
:
call_node
->
type_args
)
{
auto
new_ty_arg
=
this
->
VisitType
(
ty_arg
);
ty_args
.
push_back
(
new_ty_arg
);
all_ty_args_
unchanged
&=
new_ty_arg
.
same_as
(
ty_arg
);
unchanged
&=
new_ty_arg
.
same_as
(
ty_arg
);
}
tvm
::
Array
<
Expr
>
call_args
;
bool
all_args_unchanged
=
true
;
for
(
auto
arg
:
call_node
->
args
)
{
auto
new_arg
=
this
->
Mutate
(
arg
);
call_args
.
push_back
(
new_arg
);
all_args_
unchanged
&=
new_arg
.
same_as
(
arg
);
unchanged
&=
new_arg
.
same_as
(
arg
);
}
if
(
all_ty_args_unchanged
&&
all_args_unchanged
&&
call_node
->
op
.
same_as
(
op
))
{
return
e
;
if
(
unchanged
)
{
return
GetRef
<
Expr
>
(
call_node
);
}
else
{
return
CallNode
::
make
(
op
,
call_args
,
call_node
->
attrs
,
ty_args
);
return
CallNode
::
make
(
new_
op
,
call_args
,
call_node
->
attrs
,
ty_args
);
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
LetNode
*
op
,
const
Expr
&
e
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
LetNode
*
op
)
{
Var
var
=
Downcast
<
Var
>
(
this
->
Mutate
(
op
->
var
));
auto
type
=
this
->
VisitType
(
op
->
value_type
);
auto
value
=
this
->
Mutate
(
op
->
value
);
auto
body
=
this
->
Mutate
(
op
->
body
);
if
(
var
.
same_as
(
op
->
var
)
&&
type
.
same_as
(
op
->
value_type
)
&&
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
e
;
if
(
var
.
same_as
(
op
->
var
)
&&
type
.
same_as
(
op
->
value_type
)
&&
value
.
same_as
(
op
->
value
)
&&
body
.
same_as
(
op
->
body
))
{
return
GetRef
<
Expr
>
(
op
);
}
else
{
return
LetNode
::
make
(
var
,
value
,
body
,
type
);
}
}
Expr
ExprMutator
::
VisitExpr_
(
const
IfNode
*
op
,
const
Expr
&
e
)
{
Expr
ExprMutator
::
VisitExpr_
(
const
IfNode
*
op
)
{
auto
guard
=
this
->
Mutate
(
op
->
cond
);
auto
true_b
=
this
->
Mutate
(
op
->
true_branch
);
auto
false_b
=
this
->
Mutate
(
op
->
false_branch
);
if
(
op
->
cond
==
guard
&&
true_b
==
op
->
true_branch
&&
false_b
==
op
->
false_branch
)
{
return
e
;
if
(
op
->
cond
.
same_as
(
guard
)
&&
op
->
true_branch
.
same_as
(
true_b
)
&&
op
->
false_branch
.
same_as
(
false_b
))
{
return
GetRef
<
Expr
>
(
op
);;
}
else
{
return
IfNode
::
make
(
guard
,
true_b
,
false_b
);
}
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
Type
ExprMutator
::
VisitType
(
const
Type
&
t
)
{
return
t
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
return
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
VarNode
*
op
)
{
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
GlobalVarNode
*
op
)
{
return
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
GlobalVarNode
*
op
)
{
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
ConstantNode
*
op
)
{
return
;
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
ConstantNode
*
op
)
{
}
void
ExprVisitor
::
ExprVisitor
::
VisitExpr_
(
const
TupleNode
*
op
)
{
for
(
auto
field
:
op
->
fields
)
{
...
...
@@ -202,4 +210,3 @@ void ExprVisitor::VisitType(const Type& t) { return; }
}
// namespace relay
}
// namespace tvm
src/relay/pass/type_visitor.h
View file @
ec0d497c
...
...
@@ -78,7 +78,8 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Array
<
TypeConstraint
>
type_constraints
;
for
(
auto
type_cs
:
op
->
type_constraints
)
{
auto
new_type_cs
=
VisitType
(
type_cs
);
if
(
const
TypeConstraintNode
*
tin
=
As
<
TypeConstraintNode
>
(
new_type_cs
))
{
if
(
const
TypeConstraintNode
*
tin
=
new_type_cs
.
as_derived
<
TypeConstraintNode
>
())
{
type_constraints
.
push_back
(
GetRef
<
TypeConstraint
>
(
tin
));
}
else
{
CHECK
(
false
)
<<
new_type_cs
<<
std
::
endl
;
...
...
tests/cpp/expr_test.cc
View file @
ec0d497c
...
...
@@ -20,7 +20,7 @@ TEST(ExprNodeRef, Basic) {
Var
x
(
"x"
);
Expr
z
=
max
(
x
+
1
+
2
,
100
);
const
ir
::
Max
*
op
=
z
.
as
<
ir
::
Max
>
();
CHECK
(
op
->
GetNodeRef
(
).
same_as
(
z
));
CHECK
(
NodeRef
(
op
->
GetNodePtr
()
).
same_as
(
z
));
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment