Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
09efdc9d
Commit
09efdc9d
authored
Oct 22, 2018
by
雾雨魔理沙
Committed by
Yizhi Liu
Oct 22, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] Fix format (#1957)
* save * fix format
parent
390acc52
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
61 additions
and
56 deletions
+61
-56
src/relay/ir/base.cc
+3
-3
src/relay/ir/environment.cc
+2
-2
src/relay/ir/expr.cc
+17
-17
src/relay/ir/type.cc
+14
-14
src/relay/pass/alpha_eq.cc
+8
-3
src/relay/pass/kind_check.cc
+3
-3
src/relay/pass/type_infer.cc
+1
-1
src/relay/pass/util.cc
+7
-7
src/relay/pass/well_formed.cc
+6
-6
No files found.
src/relay/ir/base.cc
View file @
09efdc9d
...
@@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) {
...
@@ -33,7 +33,7 @@ SourceName SourceName::Get(const std::string& name) {
}
}
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
SourceNameNode
>
([](
const
SourceNameNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
SourceNameNode
>
([](
const
SourceNameNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"SourceName("
<<
node
->
name
<<
", "
<<
node
<<
")"
;
p
->
stream
<<
"SourceName("
<<
node
->
name
<<
", "
<<
node
<<
")"
;
});
});
...
@@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
...
@@ -54,12 +54,12 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE
(
SpanNode
);
TVM_REGISTER_NODE_TYPE
(
SpanNode
);
TVM_REGISTER_API
(
"relay._make.Span"
)
TVM_REGISTER_API
(
"relay._make.Span"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
SpanNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
*
ret
=
SpanNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
SpanNode
>
([](
const
SpanNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
SpanNode
>
([](
const
SpanNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"SpanNode("
<<
node
->
source
<<
", "
<<
node
->
lineno
<<
", "
p
->
stream
<<
"SpanNode("
<<
node
->
source
<<
", "
<<
node
->
lineno
<<
", "
<<
node
->
col_offset
<<
")"
;
<<
node
->
col_offset
<<
")"
;
});
});
...
...
src/relay/ir/environment.cc
View file @
09efdc9d
...
@@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) {
...
@@ -73,12 +73,12 @@ Function EnvironmentNode::Lookup(const GlobalVar& var) {
return
(
*
it
).
second
;
return
(
*
it
).
second
;
}
}
Function
EnvironmentNode
::
Lookup
(
const
std
::
string
&
name
)
{
Function
EnvironmentNode
::
Lookup
(
const
std
::
string
&
name
)
{
GlobalVar
id
=
this
->
GetGlobalVar
(
name
);
GlobalVar
id
=
this
->
GetGlobalVar
(
name
);
return
this
->
Lookup
(
id
);
return
this
->
Lookup
(
id
);
}
}
void
EnvironmentNode
::
Update
(
const
Environment
&
env
)
{
void
EnvironmentNode
::
Update
(
const
Environment
&
env
)
{
for
(
auto
pair
:
env
->
functions
)
{
for
(
auto
pair
:
env
->
functions
)
{
this
->
Update
(
pair
.
first
,
pair
.
second
);
this
->
Update
(
pair
.
first
,
pair
.
second
);
}
}
...
...
src/relay/ir/expr.cc
View file @
09efdc9d
...
@@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) {
...
@@ -20,12 +20,12 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE
(
ConstantNode
);
TVM_REGISTER_NODE_TYPE
(
ConstantNode
);
TVM_REGISTER_API
(
"relay._make.Constant"
)
TVM_REGISTER_API
(
"relay._make.Constant"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
ConstantNode
::
make
(
args
[
0
]);
*
ret
=
ConstantNode
::
make
(
args
[
0
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
ConstantNode
>
([](
const
ConstantNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
ConstantNode
>
([](
const
ConstantNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Constant(TODO)"
;
p
->
stream
<<
"Constant(TODO)"
;
});
});
...
@@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
...
@@ -49,12 +49,12 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE
(
TupleNode
);
TVM_REGISTER_NODE_TYPE
(
TupleNode
);
TVM_REGISTER_API
(
"relay._make.Tuple"
)
TVM_REGISTER_API
(
"relay._make.Tuple"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TupleNode
::
make
(
args
[
0
]);
*
ret
=
TupleNode
::
make
(
args
[
0
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TupleNode
>
([](
const
TupleNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
TupleNode
>
([](
const
TupleNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Tuple("
<<
node
->
fields
<<
")"
;
p
->
stream
<<
"Tuple("
<<
node
->
fields
<<
")"
;
});
});
...
@@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
...
@@ -68,12 +68,12 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE
(
VarNode
);
TVM_REGISTER_NODE_TYPE
(
VarNode
);
TVM_REGISTER_API
(
"relay._make.Var"
)
TVM_REGISTER_API
(
"relay._make.Var"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
VarNode
::
make
(
args
[
0
],
args
[
1
]);
*
ret
=
VarNode
::
make
(
args
[
0
],
args
[
1
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
VarNode
>
([](
const
VarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
VarNode
>
([](
const
VarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"Var("
<<
node
->
name_hint
;
p
->
stream
<<
"Var("
<<
node
->
name_hint
;
if
(
node
->
type_annotation
.
defined
())
{
if
(
node
->
type_annotation
.
defined
())
{
p
->
stream
<<
", ty="
;
p
->
stream
<<
", ty="
;
...
@@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
...
@@ -91,12 +91,12 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
TVM_REGISTER_NODE_TYPE
(
GlobalVarNode
);
TVM_REGISTER_NODE_TYPE
(
GlobalVarNode
);
TVM_REGISTER_API
(
"relay._make.GlobalVar"
)
TVM_REGISTER_API
(
"relay._make.GlobalVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
GlobalVarNode
::
make
(
args
[
0
]);
*
ret
=
GlobalVarNode
::
make
(
args
[
0
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
GlobalVarNode
>
([](
const
GlobalVarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
GlobalVarNode
>
([](
const
GlobalVarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"GlobalVar("
<<
node
->
name_hint
<<
")"
;
p
->
stream
<<
"GlobalVar("
<<
node
->
name_hint
<<
")"
;
});
});
...
@@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const {
...
@@ -124,13 +124,13 @@ FuncType FunctionNode::func_type_annotation() const {
TVM_REGISTER_NODE_TYPE
(
FunctionNode
);
TVM_REGISTER_NODE_TYPE
(
FunctionNode
);
TVM_REGISTER_API
(
"relay._make.Function"
)
TVM_REGISTER_API
(
"relay._make.Function"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
FunctionNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
*
ret
=
FunctionNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
FunctionNode
>
([](
const
FunctionNode
*
node
,
.
set_dispatch
<
FunctionNode
>
([](
const
FunctionNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"FunctionNode("
<<
node
->
params
<<
", "
<<
node
->
ret_type
p
->
stream
<<
"FunctionNode("
<<
node
->
params
<<
", "
<<
node
->
ret_type
<<
", "
<<
node
->
body
<<
", "
<<
node
->
type_params
<<
")"
;
<<
", "
<<
node
->
body
<<
", "
<<
node
->
type_params
<<
")"
;
});
});
...
@@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
...
@@ -148,12 +148,12 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE
(
CallNode
);
TVM_REGISTER_NODE_TYPE
(
CallNode
);
TVM_REGISTER_API
(
"relay._make.Call"
)
TVM_REGISTER_API
(
"relay._make.Call"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
CallNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
*
ret
=
CallNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
CallNode
>
([](
const
CallNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
CallNode
>
([](
const
CallNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"CallNode("
<<
node
->
op
<<
", "
<<
node
->
args
<<
", "
p
->
stream
<<
"CallNode("
<<
node
->
op
<<
", "
<<
node
->
args
<<
", "
<<
node
->
attrs
<<
", "
<<
node
->
type_args
<<
")"
;
<<
node
->
attrs
<<
", "
<<
node
->
type_args
<<
")"
;
});
});
...
@@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) {
...
@@ -169,12 +169,12 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE
(
LetNode
);
TVM_REGISTER_NODE_TYPE
(
LetNode
);
TVM_REGISTER_API
(
"relay._make.Let"
)
TVM_REGISTER_API
(
"relay._make.Let"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
LetNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
*
ret
=
LetNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
LetNode
>
([](
const
LetNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
LetNode
>
([](
const
LetNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"LetNode("
<<
node
->
var
<<
", "
<<
node
->
value
p
->
stream
<<
"LetNode("
<<
node
->
var
<<
", "
<<
node
->
value
<<
", "
<<
node
->
body
<<
")"
;
<<
", "
<<
node
->
body
<<
")"
;
});
});
...
@@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
...
@@ -189,12 +189,12 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
TVM_REGISTER_NODE_TYPE
(
IfNode
);
TVM_REGISTER_NODE_TYPE
(
IfNode
);
TVM_REGISTER_API
(
"relay._make.If"
).
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
TVM_REGISTER_API
(
"relay._make.If"
).
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
IfNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
*
ret
=
IfNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IfNode
>
([](
const
IfNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
IfNode
>
([](
const
IfNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"IfNode("
<<
node
->
cond
<<
", "
<<
node
->
true_branch
p
->
stream
<<
"IfNode("
<<
node
->
cond
<<
", "
<<
node
->
true_branch
<<
", "
<<
node
->
false_branch
<<
")"
;
<<
", "
<<
node
->
false_branch
<<
")"
;
});
});
...
...
src/relay/ir/type.cc
View file @
09efdc9d
...
@@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
...
@@ -25,14 +25,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
TVM_REGISTER_NODE_TYPE
(
TensorTypeNode
);
TVM_REGISTER_NODE_TYPE
(
TensorTypeNode
);
TVM_REGISTER_API
(
"relay._make.TensorType"
)
TVM_REGISTER_API
(
"relay._make.TensorType"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Array
<
IndexExpr
>
shape
=
args
[
0
];
Array
<
IndexExpr
>
shape
=
args
[
0
];
*
ret
=
TensorTypeNode
::
make
(
shape
,
args
[
1
]);
*
ret
=
TensorTypeNode
::
make
(
shape
,
args
[
1
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TensorTypeNode
>
([](
const
TensorTypeNode
*
node
,
.
set_dispatch
<
TensorTypeNode
>
([](
const
TensorTypeNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"TensorType("
<<
node
->
shape
<<
", "
<<
node
->
dtype
<<
")"
;
p
->
stream
<<
"TensorType("
<<
node
->
shape
<<
", "
<<
node
->
dtype
<<
")"
;
});
});
...
@@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
...
@@ -46,15 +46,15 @@ TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
TVM_REGISTER_NODE_TYPE
(
TypeVarNode
);
TVM_REGISTER_NODE_TYPE
(
TypeVarNode
);
TVM_REGISTER_API
(
"relay._make.TypeVar"
)
TVM_REGISTER_API
(
"relay._make.TypeVar"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
int
kind
=
args
[
1
];
int
kind
=
args
[
1
];
*
ret
=
*
ret
=
TypeVarNode
::
make
(
args
[
0
],
static_cast
<
TypeVarNode
::
Kind
>
(
kind
));
TypeVarNode
::
make
(
args
[
0
],
static_cast
<
TypeVarNode
::
Kind
>
(
kind
));
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TypeVarNode
>
([](
const
TypeVarNode
*
node
,
.
set_dispatch
<
TypeVarNode
>
([](
const
TypeVarNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"TypeVarNode("
<<
node
->
var
->
name_hint
<<
", "
p
->
stream
<<
"TypeVarNode("
<<
node
->
var
->
name_hint
<<
", "
<<
node
->
kind
<<
")"
;
<<
node
->
kind
<<
")"
;
});
});
...
@@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
...
@@ -95,13 +95,13 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE
(
FuncTypeNode
);
TVM_REGISTER_NODE_TYPE
(
FuncTypeNode
);
TVM_REGISTER_API
(
"relay._make.FuncType"
)
TVM_REGISTER_API
(
"relay._make.FuncType"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
FuncTypeNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
*
ret
=
FuncTypeNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
FuncTypeNode
>
([](
const
FuncTypeNode
*
node
,
.
set_dispatch
<
FuncTypeNode
>
([](
const
FuncTypeNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"FuncTypeNode("
<<
node
->
type_params
<<
", "
p
->
stream
<<
"FuncTypeNode("
<<
node
->
type_params
<<
", "
<<
node
->
arg_types
<<
", "
<<
node
->
ret_type
<<
", "
<<
node
->
arg_types
<<
", "
<<
node
->
ret_type
<<
", "
<<
node
->
type_constraints
<<
")"
;
<<
node
->
type_constraints
<<
")"
;
...
@@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
...
@@ -122,12 +122,12 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE
(
TypeRelationNode
);
TVM_REGISTER_NODE_TYPE
(
TypeRelationNode
);
TVM_REGISTER_API
(
"relay._make.TypeRelation"
)
TVM_REGISTER_API
(
"relay._make.TypeRelation"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TypeRelationNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
*
ret
=
TypeRelationNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TypeRelationNode
>
([](
const
TypeRelationNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
.
set_dispatch
<
TypeRelationNode
>
([](
const
TypeRelationNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"TypeRelationNode("
p
->
stream
<<
"TypeRelationNode("
<<
node
->
func
->
name
<<
node
->
func
->
name
<<
", "
<<
node
->
args
<<
")"
;
<<
", "
<<
node
->
args
<<
")"
;
...
@@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
...
@@ -142,13 +142,13 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
TVM_REGISTER_NODE_TYPE
(
TupleTypeNode
);
TVM_REGISTER_NODE_TYPE
(
TupleTypeNode
);
TVM_REGISTER_API
(
"relay._make.TupleType"
)
TVM_REGISTER_API
(
"relay._make.TupleType"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
TupleTypeNode
::
make
(
args
[
0
]);
*
ret
=
TupleTypeNode
::
make
(
args
[
0
]);
});
});
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
TVM_STATIC_IR_FUNCTOR_REGISTER
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TupleTypeNode
>
([](
const
TupleTypeNode
*
node
,
.
set_dispatch
<
TupleTypeNode
>
([](
const
TupleTypeNode
*
node
,
tvm
::
IRPrinter
*
p
)
{
tvm
::
IRPrinter
*
p
)
{
p
->
stream
<<
"TupleTypeNode("
<<
node
->
fields
<<
")"
;
p
->
stream
<<
"TupleTypeNode("
<<
node
->
fields
<<
")"
;
});
});
...
...
src/relay/pass/alpha_eq.cc
View file @
09efdc9d
...
@@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
...
@@ -193,11 +193,13 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
};
};
bool
AlphaEqual
(
const
Type
&
t1
,
const
Type
&
t2
)
{
bool
AlphaEqual
(
const
Type
&
t1
,
const
Type
&
t2
)
{
if
(
t1
.
defined
()
!=
t2
.
defined
())
if
(
t1
.
defined
()
!=
t2
.
defined
())
{
return
false
;
return
false
;
}
if
(
!
t1
.
defined
())
if
(
!
t1
.
defined
())
{
return
true
;
return
true
;
}
TypeAlphaEq
aeq
;
TypeAlphaEq
aeq
;
aeq
.
VisitType
(
t1
,
t2
);
aeq
.
VisitType
(
t1
,
t2
);
...
@@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
...
@@ -273,7 +275,10 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
for
(
size_t
i
=
0
;
i
<
func1
->
params
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
func1
->
params
.
size
();
++
i
)
{
MergeVarDecl
(
func1
->
params
[
i
],
func2
->
params
[
i
]);
MergeVarDecl
(
func1
->
params
[
i
],
func2
->
params
[
i
]);
}
}
if
(
!
equal
)
return
;
if
(
!
equal
)
{
return
;
}
for
(
size_t
i
=
0U
;
i
<
func1
->
type_params
.
size
();
i
++
)
{
for
(
size_t
i
=
0U
;
i
<
func1
->
type_params
.
size
();
i
++
)
{
equal
=
equal
&&
AlphaEqual
(
func1
->
type_params
[
i
],
func2
->
type_params
[
i
]);
equal
=
equal
&&
AlphaEqual
(
func1
->
type_params
[
i
],
func2
->
type_params
[
i
]);
...
...
src/relay/pass/kind_check.cc
View file @
09efdc9d
...
@@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> {
...
@@ -29,11 +29,11 @@ struct KindChecker : TypeVisitor<> {
// checks if t is an incomplete node of kind k or a type param of kind k
// checks if t is an incomplete node of kind k or a type param of kind k
bool
MatchKind
(
const
Type
&
t
,
Kind
k
)
{
bool
MatchKind
(
const
Type
&
t
,
Kind
k
)
{
if
(
const
IncompleteTypeNode
*
tv
=
t
.
as
<
IncompleteTypeNode
>
())
{
if
(
const
IncompleteTypeNode
*
tv
=
t
.
as
<
IncompleteTypeNode
>
())
{
return
tv
->
kind
==
k
;
return
tv
->
kind
==
k
;
}
}
if
(
const
TypeVarNode
*
tp
=
t
.
as
<
TypeVarNode
>
())
{
if
(
const
TypeVarNode
*
tp
=
t
.
as
<
TypeVarNode
>
())
{
return
tp
->
kind
==
k
;
return
tp
->
kind
==
k
;
}
}
...
@@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> {
...
@@ -93,7 +93,7 @@ struct KindChecker : TypeVisitor<> {
}
}
}
}
bool
Check
(
const
Type
&
t
)
{
bool
Check
(
const
Type
&
t
)
{
this
->
VisitType
(
t
);
this
->
VisitType
(
t
);
return
valid
;
return
valid
;
}
}
...
...
src/relay/pass/type_infer.cc
View file @
09efdc9d
...
@@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator {
...
@@ -379,7 +379,7 @@ class TypeInferencer::Resolver : public ExprMutator {
return
new_e
;
return
new_e
;
}
}
Type
VisitType
(
const
Type
&
t
)
final
{
Type
VisitType
(
const
Type
&
t
)
final
{
return
solver_
->
Resolve
(
t
);
return
solver_
->
Resolve
(
t
);
}
}
...
...
src/relay/pass/util.cc
View file @
09efdc9d
...
@@ -14,10 +14,10 @@ namespace relay {
...
@@ -14,10 +14,10 @@ namespace relay {
class
FreeVar
;
class
FreeVar
;
class
FreeTypeVar
:
private
TypeVisitor
<>
{
class
FreeTypeVar
:
private
TypeVisitor
<>
{
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>
*
free_vars
;
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>*
free_vars
;
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>
*
bound_vars
;
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>*
bound_vars
;
FreeTypeVar
(
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>
*
free_vars
,
FreeTypeVar
(
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>*
free_vars
,
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>
*
bound_vars
)
:
std
::
unordered_set
<
TypeVar
,
NodeHash
,
NodeEqual
>*
bound_vars
)
:
free_vars
(
free_vars
),
bound_vars
(
bound_vars
)
{
}
free_vars
(
free_vars
),
bound_vars
(
bound_vars
)
{
}
void
VisitType_
(
const
TypeVarNode
*
tp
)
final
{
void
VisitType_
(
const
TypeVarNode
*
tp
)
final
{
...
@@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> {
...
@@ -45,7 +45,7 @@ class FreeTypeVar : private TypeVisitor<> {
};
};
class
FreeVar
:
public
ExprVisitor
{
class
FreeVar
:
public
ExprVisitor
{
void
VisitExpr_
(
const
VarNode
*
v
)
final
{
void
VisitExpr_
(
const
VarNode
*
v
)
final
{
auto
var
=
GetRef
<
Var
>
(
v
);
auto
var
=
GetRef
<
Var
>
(
v
);
if
(
bound_vars
.
count
(
var
)
==
0
)
{
if
(
bound_vars
.
count
(
var
)
==
0
)
{
free_vars
.
insert
(
var
);
free_vars
.
insert
(
var
);
...
@@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor {
...
@@ -55,7 +55,7 @@ class FreeVar : public ExprVisitor {
}
}
}
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
for
(
const
auto
&
tp
:
f
->
type_params
)
{
bound_types
.
insert
(
tp
);
bound_types
.
insert
(
tp
);
}
}
...
@@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor {
...
@@ -66,7 +66,7 @@ class FreeVar : public ExprVisitor {
VisitType
(
f
->
ret_type
);
VisitType
(
f
->
ret_type
);
}
}
void
VisitExpr_
(
const
LetNode
*
l
)
final
{
void
VisitExpr_
(
const
LetNode
*
l
)
final
{
bound_vars
.
insert
(
l
->
var
);
bound_vars
.
insert
(
l
->
var
);
VisitExpr
(
l
->
value
);
VisitExpr
(
l
->
value
);
VisitExpr
(
l
->
body
);
VisitExpr
(
l
->
body
);
...
...
src/relay/pass/well_formed.cc
View file @
09efdc9d
...
@@ -18,14 +18,14 @@ class WellFormedChecker : private ExprVisitor {
...
@@ -18,14 +18,14 @@ class WellFormedChecker : private ExprVisitor {
std
::
unordered_set
<
Var
,
NodeHash
,
NodeEqual
>
s
;
std
::
unordered_set
<
Var
,
NodeHash
,
NodeEqual
>
s
;
void
Check
(
const
Var
&
v
)
{
void
Check
(
const
Var
&
v
)
{
if
(
s
.
count
(
v
)
!=
0
)
{
if
(
s
.
count
(
v
)
!=
0
)
{
well_formed
=
false
;
well_formed
=
false
;
}
}
s
.
insert
(
v
);
s
.
insert
(
v
);
}
}
void
VisitExpr_
(
const
LetNode
*
l
)
final
{
void
VisitExpr_
(
const
LetNode
*
l
)
final
{
// we do letrec only for FunctionNode,
// we do letrec only for FunctionNode,
// but shadowing let in let binding is likely programming error, and we should forbidden it.
// but shadowing let in let binding is likely programming error, and we should forbidden it.
Check
(
l
->
var
);
Check
(
l
->
var
);
...
@@ -33,21 +33,21 @@ class WellFormedChecker : private ExprVisitor {
...
@@ -33,21 +33,21 @@ class WellFormedChecker : private ExprVisitor {
CheckWellFormed
(
l
->
body
);
CheckWellFormed
(
l
->
body
);
}
}
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
void
VisitExpr_
(
const
FunctionNode
*
f
)
final
{
for
(
const
Var
&
param
:
f
->
params
)
{
for
(
const
Var
&
param
:
f
->
params
)
{
Check
(
param
);
Check
(
param
);
}
}
CheckWellFormed
(
f
->
body
);
CheckWellFormed
(
f
->
body
);
}
}
public
:
public
:
bool
CheckWellFormed
(
const
Expr
&
e
)
{
bool
CheckWellFormed
(
const
Expr
&
e
)
{
this
->
VisitExpr
(
e
);
this
->
VisitExpr
(
e
);
return
well_formed
;
return
well_formed
;
}
}
};
};
bool
WellFormed
(
const
Expr
&
e
)
{
bool
WellFormed
(
const
Expr
&
e
)
{
return
WellFormedChecker
().
CheckWellFormed
(
e
);
return
WellFormedChecker
().
CheckWellFormed
(
e
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment