Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6027412b
Unverified
Commit
6027412b
authored
Mar 15, 2020
by
Tianqi Chen
Committed by
GitHub
Mar 15, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Update the type_keys to reflect the code-org (#5074)
parent
7c5ff508
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
255 additions
and
118 deletions
+255
-118
include/tvm/ir/expr.h
+1
-1
include/tvm/ir/module.h
+1
-1
include/tvm/ir/span.h
+2
-2
include/tvm/ir/transform.h
+3
-3
include/tvm/ir/type.h
+10
-8
include/tvm/ir/type_relation.h
+3
-3
python/tvm/ir/__init__.py
+1
-1
python/tvm/ir/base.py
+2
-2
python/tvm/ir/expr.py
+1
-1
python/tvm/ir/json_compact.py
+24
-0
python/tvm/ir/module.py
+1
-1
python/tvm/ir/transform.py
+5
-5
python/tvm/ir/type.py
+19
-6
python/tvm/ir/type_relation.py
+2
-1
src/ir/transform.cc
+2
-2
tests/python/relay/test_ir_nodes.py
+0
-78
tests/python/relay/test_json_compact.py
+70
-3
tests/python/unittest/test_ir_type.py
+108
-0
No files found.
include/tvm/ir/expr.h
View file @
6027412b
...
@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
...
@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
GlobalVar"
;
static
constexpr
const
char
*
_type_key
=
"GlobalVar"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
GlobalVarNode
,
RelayExprNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
GlobalVarNode
,
RelayExprNode
);
};
};
...
...
include/tvm/ir/module.h
View file @
6027412b
...
@@ -226,7 +226,7 @@ class IRModuleNode : public Object {
...
@@ -226,7 +226,7 @@ class IRModuleNode : public Object {
*/
*/
TVM_DLL
std
::
unordered_set
<
std
::
string
>
Imports
()
const
;
TVM_DLL
std
::
unordered_set
<
std
::
string
>
Imports
()
const
;
static
constexpr
const
char
*
_type_key
=
"
relay.
Module"
;
static
constexpr
const
char
*
_type_key
=
"
IR
Module"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
IRModuleNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
IRModuleNode
,
Object
);
private
:
private
:
...
...
include/tvm/ir/span.h
View file @
6027412b
...
@@ -44,7 +44,7 @@ class SourceNameNode : public Object {
...
@@ -44,7 +44,7 @@ class SourceNameNode : public Object {
// override attr visitor
// override attr visitor
void
VisitAttrs
(
AttrVisitor
*
v
)
{
v
->
Visit
(
"name"
,
&
name
);
}
void
VisitAttrs
(
AttrVisitor
*
v
)
{
v
->
Visit
(
"name"
,
&
name
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
SourceName"
;
static
constexpr
const
char
*
_type_key
=
"SourceName"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
SourceNameNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
SourceNameNode
,
Object
);
};
};
...
@@ -89,7 +89,7 @@ class SpanNode : public Object {
...
@@ -89,7 +89,7 @@ class SpanNode : public Object {
TVM_DLL
static
Span
make
(
SourceName
source
,
int
lineno
,
int
col_offset
);
TVM_DLL
static
Span
make
(
SourceName
source
,
int
lineno
,
int
col_offset
);
static
constexpr
const
char
*
_type_key
=
"
relay.
Span"
;
static
constexpr
const
char
*
_type_key
=
"Span"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
SpanNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
SpanNode
,
Object
);
};
};
...
...
include/tvm/ir/transform.h
View file @
6027412b
...
@@ -110,7 +110,7 @@ class PassContextNode : public Object {
...
@@ -110,7 +110,7 @@ class PassContextNode : public Object {
v
->
Visit
(
"disabled_pass"
,
&
disabled_pass
);
v
->
Visit
(
"disabled_pass"
,
&
disabled_pass
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay
.PassContext"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.PassContext"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
PassContextNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
PassContextNode
,
Object
);
};
};
...
@@ -206,7 +206,7 @@ class PassInfoNode : public Object {
...
@@ -206,7 +206,7 @@ class PassInfoNode : public Object {
v
->
Visit
(
"required"
,
&
required
);
v
->
Visit
(
"required"
,
&
required
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay
.PassInfo"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.PassInfo"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
PassInfoNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
PassInfoNode
,
Object
);
};
};
...
@@ -265,7 +265,7 @@ class PassNode : public Object {
...
@@ -265,7 +265,7 @@ class PassNode : public Object {
void
VisitAttrs
(
AttrVisitor
*
v
)
{}
void
VisitAttrs
(
AttrVisitor
*
v
)
{}
static
constexpr
const
char
*
_type_key
=
"
relay
.Pass"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.Pass"
;
TVM_DECLARE_BASE_OBJECT_INFO
(
PassNode
,
Object
);
TVM_DECLARE_BASE_OBJECT_INFO
(
PassNode
,
Object
);
};
};
...
...
include/tvm/ir/type.h
View file @
6027412b
...
@@ -78,7 +78,7 @@ class TypeNode : public Object {
...
@@ -78,7 +78,7 @@ class TypeNode : public Object {
*/
*/
mutable
Span
span
;
mutable
Span
span
;
static
constexpr
const
char
*
_type_key
=
"
relay.
Type"
;
static
constexpr
const
char
*
_type_key
=
"Type"
;
TVM_DECLARE_BASE_OBJECT_INFO
(
TypeNode
,
Object
);
TVM_DECLARE_BASE_OBJECT_INFO
(
TypeNode
,
Object
);
};
};
...
@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
...
@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
v
->
Visit
(
"dtype"
,
&
dtype
);
v
->
Visit
(
"dtype"
,
&
dtype
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
PrimType"
;
static
constexpr
const
char
*
_type_key
=
"PrimType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
PrimTypeNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
PrimTypeNode
,
TypeNode
);
};
};
...
@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
...
@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeVar"
;
static
constexpr
const
char
*
_type_key
=
"TypeVar"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeVarNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeVarNode
,
TypeNode
);
};
};
...
@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
...
@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
v
->
Visit
(
"kind"
,
&
kind
);
v
->
Visit
(
"kind"
,
&
kind
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
GlobalTypeVar"
;
static
constexpr
const
char
*
_type_key
=
"GlobalTypeVar"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
GlobalTypeVarNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
GlobalTypeVarNode
,
TypeNode
);
};
};
...
@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
...
@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TupleType"
;
static
constexpr
const
char
*
_type_key
=
"TupleType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TupleTypeNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
TupleTypeNode
,
TypeNode
);
};
};
...
@@ -289,7 +289,7 @@ inline Type VoidType() {
...
@@ -289,7 +289,7 @@ inline Type VoidType() {
*/
*/
class
TypeConstraintNode
:
public
TypeNode
{
class
TypeConstraintNode
:
public
TypeNode
{
public
:
public
:
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeConstraint"
;
static
constexpr
const
char
*
_type_key
=
"TypeConstraint"
;
TVM_DECLARE_BASE_OBJECT_INFO
(
TypeConstraintNode
,
TypeNode
);
TVM_DECLARE_BASE_OBJECT_INFO
(
TypeConstraintNode
,
TypeNode
);
};
};
...
@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
...
@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
FuncType"
;
static
constexpr
const
char
*
_type_key
=
"FuncType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
FuncTypeNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
FuncTypeNode
,
TypeNode
);
};
};
...
@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
...
@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
IncompleteType"
;
static
constexpr
const
char
*
_type_key
=
"IncompleteType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
IncompleteTypeNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
IncompleteTypeNode
,
TypeNode
);
};
};
...
@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
...
@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
}
}
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static
constexpr
const
char
*
_type_key
=
"relay.RefType"
;
static
constexpr
const
char
*
_type_key
=
"relay.RefType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
RelayRefTypeNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
RelayRefTypeNode
,
TypeNode
);
};
};
...
...
include/tvm/ir/type_relation.h
View file @
6027412b
...
@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
...
@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeCall"
;
static
constexpr
const
char
*
_type_key
=
"TypeCall"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeCallNode
,
TypeNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeCallNode
,
TypeNode
);
};
};
...
@@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
...
@@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
// solver is not serializable.
// solver is not serializable.
void
VisitAttrs
(
AttrVisitor
*
v
)
{}
void
VisitAttrs
(
AttrVisitor
*
v
)
{}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeReporter"
;
static
constexpr
const
char
*
_type_key
=
"TypeReporter"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeReporterNode
,
Object
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeReporterNode
,
Object
);
};
};
...
@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
...
@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
v
->
Visit
(
"span"
,
&
span
);
v
->
Visit
(
"span"
,
&
span
);
}
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeRelation"
;
static
constexpr
const
char
*
_type_key
=
"TypeRelation"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeRelationNode
,
TypeConstraintNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeRelationNode
,
TypeConstraintNode
);
};
};
...
...
python/tvm/ir/__init__.py
View file @
6027412b
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
# pylint: disable=unused-import
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
"""Common data structures across all IR variants."""
from
.base
import
SourceName
,
Span
,
Node
,
EnvFunc
,
load_json
,
save_json
from
.base
import
SourceName
,
Span
,
Node
,
EnvFunc
,
load_json
,
save_json
from
.type
import
Type
,
TypeKind
,
TypeVar
,
GlobalTypeVar
,
TupleType
from
.type
import
Type
,
TypeKind
,
PrimType
,
TypeVar
,
GlobalTypeVar
,
TupleType
from
.type
import
TypeConstraint
,
FuncType
,
IncompleteType
,
RelayRefType
from
.type
import
TypeConstraint
,
FuncType
,
IncompleteType
,
RelayRefType
from
.tensor_type
import
TensorType
from
.tensor_type
import
TensorType
from
.type_relation
import
TypeCall
,
TypeRelation
from
.type_relation
import
TypeCall
,
TypeRelation
...
...
python/tvm/ir/base.py
View file @
6027412b
...
@@ -56,7 +56,7 @@ class Node(Object):
...
@@ -56,7 +56,7 @@ class Node(Object):
return
_ffi_api
.
PrettyPrint
(
self
)
return
_ffi_api
.
PrettyPrint
(
self
)
@tvm._ffi.register_object
(
"
relay.
SourceName"
)
@tvm._ffi.register_object
(
"SourceName"
)
class
SourceName
(
Object
):
class
SourceName
(
Object
):
"""A identifier for a source location.
"""A identifier for a source location.
...
@@ -69,7 +69,7 @@ class SourceName(Object):
...
@@ -69,7 +69,7 @@ class SourceName(Object):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
SourceName
,
name
)
self
.
__init_handle_by_constructor__
(
_ffi_api
.
SourceName
,
name
)
@tvm._ffi.register_object
(
"
relay.
Span"
)
@tvm._ffi.register_object
(
"Span"
)
class
Span
(
Object
):
class
Span
(
Object
):
"""Specifies a location in a source program.
"""Specifies a location in a source program.
...
...
python/tvm/ir/expr.py
View file @
6027412b
...
@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr):
...
@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr):
return
ret
return
ret
@tvm._ffi.register_object
(
"
relay.
GlobalVar"
)
@tvm._ffi.register_object
(
"GlobalVar"
)
class
GlobalVar
(
RelayExpr
):
class
GlobalVar
(
RelayExpr
):
"""A global variable in the IR.
"""A global variable in the IR.
...
...
python/tvm/ir/json_compact.py
View file @
6027412b
...
@@ -62,11 +62,35 @@ def create_updater_06_to_07():
...
@@ -62,11 +62,35 @@ def create_updater_06_to_07():
# set vindex to null
# set vindex to null
nodes
[
vindex
][
"type_key"
]
=
""
nodes
[
vindex
][
"type_key"
]
=
""
del
item
[
"attrs"
][
"var"
]
del
item
[
"attrs"
][
"var"
]
assert
item
[
"type_key"
]
.
startswith
(
"relay."
)
item
[
"type_key"
]
=
item
[
"type_key"
][
len
(
"relay."
):]
return
item
return
item
def
_rename
(
new_name
):
def
_convert
(
item
,
_
):
item
[
"type_key"
]
=
new_name
return
item
return
_convert
node_map
=
{
node_map
=
{
"relay.TypeVar"
:
_ftype_var
,
"relay.TypeVar"
:
_ftype_var
,
"relay.GlobalTypeVar"
:
_ftype_var
,
"relay.GlobalTypeVar"
:
_ftype_var
,
"relay.Type"
:
_rename
(
"Type"
),
"relay.TupleType"
:
_rename
(
"TupleType"
),
"relay.TypeConstraint"
:
_rename
(
"TypeConstraint"
),
"relay.FuncType"
:
_rename
(
"FuncType"
),
"relay.IncompleteType"
:
_rename
(
"IncompleteType"
),
"relay.TypeRelation"
:
_rename
(
"TypeRelation"
),
"relay.TypeCall"
:
_rename
(
"TypeCall"
),
"relay.Module"
:
_rename
(
"IRModule"
),
"relay.SourceName"
:
_rename
(
"SourceName"
),
"relay.Span"
:
_rename
(
"Span"
),
"relay.GlobalVar"
:
_rename
(
"GlobalVar"
),
"relay.Pass"
:
_rename
(
"transform.Pass"
),
"relay.PassInfo"
:
_rename
(
"transform.PassInfo"
),
"relay.PassContext"
:
_rename
(
"transform.PassContext"
),
"relay.ModulePass"
:
_rename
(
"transform.ModulePass"
),
"relay.Sequantial"
:
_rename
(
"transform.Sequantial"
),
}
}
return
create_updater
(
node_map
,
"0.6"
,
"0.7"
)
return
create_updater
(
node_map
,
"0.6"
,
"0.7"
)
...
...
python/tvm/ir/module.py
View file @
6027412b
...
@@ -24,7 +24,7 @@ from . import type as _ty
...
@@ -24,7 +24,7 @@ from . import type as _ty
from
.
import
_ffi_api
from
.
import
_ffi_api
@tvm._ffi.register_object
(
"
relay.
Module"
)
@tvm._ffi.register_object
(
"
IR
Module"
)
class
IRModule
(
Node
):
class
IRModule
(
Node
):
"""IRModule that holds functions and type definitions.
"""IRModule that holds functions and type definitions.
...
...
python/tvm/ir/transform.py
View file @
6027412b
...
@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd
...
@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd
from
.
import
_ffi_transform_api
from
.
import
_ffi_transform_api
@tvm._ffi.register_object
(
"
relay
.PassInfo"
)
@tvm._ffi.register_object
(
"
transform
.PassInfo"
)
class
PassInfo
(
Object
):
class
PassInfo
(
Object
):
"""The class contains the meta data required by a pass. It is the
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
container of information needed by running an optimization or analysis.
...
@@ -51,7 +51,7 @@ class PassInfo(Object):
...
@@ -51,7 +51,7 @@ class PassInfo(Object):
_ffi_transform_api
.
PassInfo
,
opt_level
,
name
,
required
)
_ffi_transform_api
.
PassInfo
,
opt_level
,
name
,
required
)
@tvm._ffi.register_object
(
"
relay
.PassContext"
)
@tvm._ffi.register_object
(
"
transform
.PassContext"
)
class
PassContext
(
Object
):
class
PassContext
(
Object
):
"""The basis where a Relay optimization/analysis runs on.
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
Each pass context contains a number of auxiliary information that is used
...
@@ -112,7 +112,7 @@ class PassContext(Object):
...
@@ -112,7 +112,7 @@ class PassContext(Object):
return
_ffi_transform_api
.
GetCurrentPassContext
()
return
_ffi_transform_api
.
GetCurrentPassContext
()
@tvm._ffi.register_object
(
"
relay
.Pass"
)
@tvm._ffi.register_object
(
"
transform
.Pass"
)
class
Pass
(
Object
):
class
Pass
(
Object
):
"""The base class of all passes. All methods here are just simple wrappers
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
that are implemented in the backend. They are defined for users to
...
@@ -141,7 +141,7 @@ class Pass(Object):
...
@@ -141,7 +141,7 @@ class Pass(Object):
return
_ffi_transform_api
.
RunPass
(
self
,
mod
)
return
_ffi_transform_api
.
RunPass
(
self
,
mod
)
@tvm._ffi.register_object
(
"
relay
.ModulePass"
)
@tvm._ffi.register_object
(
"
transform
.ModulePass"
)
class
ModulePass
(
Pass
):
class
ModulePass
(
Pass
):
"""A pass that works on tvm.IRModule. Users don't need to interact with
"""A pass that works on tvm.IRModule. Users don't need to interact with
this class directly. Instead, a module pass should be created through
this class directly. Instead, a module pass should be created through
...
@@ -152,7 +152,7 @@ class ModulePass(Pass):
...
@@ -152,7 +152,7 @@ class ModulePass(Pass):
"""
"""
@tvm._ffi.register_object
(
"
relay
.Sequential"
)
@tvm._ffi.register_object
(
"
transform
.Sequential"
)
class
Sequential
(
Pass
):
class
Sequential
(
Pass
):
"""A pass that works on a sequence of pass objects. Multiple passes can be
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
executed sequentially using this class.
...
...
python/tvm/ir/type.py
View file @
6027412b
...
@@ -46,7 +46,20 @@ class TypeKind(IntEnum):
...
@@ -46,7 +46,20 @@ class TypeKind(IntEnum):
TypeData
=
6
TypeData
=
6
@tvm._ffi.register_object
(
"relay.TypeVar"
)
class
PrimType
(
Type
):
"""Primitive data type in the low level IR
Parameters
----------
dtype : str
The runtime data type relates to the primtype.
"""
def
__init__
(
self
,
dtype
):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
PrimType
,
dtype
)
@tvm._ffi.register_object
(
"TypeVar"
)
class
TypeVar
(
Type
):
class
TypeVar
(
Type
):
"""Type parameter in functions.
"""Type parameter in functions.
...
@@ -85,7 +98,7 @@ class TypeVar(Type):
...
@@ -85,7 +98,7 @@ class TypeVar(Type):
return
TypeCall
(
self
,
args
)
return
TypeCall
(
self
,
args
)
@tvm._ffi.register_object
(
"
relay.
GlobalTypeVar"
)
@tvm._ffi.register_object
(
"GlobalTypeVar"
)
class
GlobalTypeVar
(
Type
):
class
GlobalTypeVar
(
Type
):
"""A global type variable that is used for defining new types or type aliases.
"""A global type variable that is used for defining new types or type aliases.
...
@@ -120,7 +133,7 @@ class GlobalTypeVar(Type):
...
@@ -120,7 +133,7 @@ class GlobalTypeVar(Type):
return
TypeCall
(
self
,
args
)
return
TypeCall
(
self
,
args
)
@tvm._ffi.register_object
(
"
relay.
TupleType"
)
@tvm._ffi.register_object
(
"TupleType"
)
class
TupleType
(
Type
):
class
TupleType
(
Type
):
"""The type of tuple values.
"""The type of tuple values.
...
@@ -135,12 +148,12 @@ class TupleType(Type):
...
@@ -135,12 +148,12 @@ class TupleType(Type):
_ffi_api
.
TupleType
,
fields
)
_ffi_api
.
TupleType
,
fields
)
@tvm._ffi.register_object
(
"
relay.
TypeConstraint"
)
@tvm._ffi.register_object
(
"TypeConstraint"
)
class
TypeConstraint
(
Type
):
class
TypeConstraint
(
Type
):
"""Abstract class representing a type constraint."""
"""Abstract class representing a type constraint."""
@tvm._ffi.register_object
(
"
relay.
FuncType"
)
@tvm._ffi.register_object
(
"FuncType"
)
class
FuncType
(
Type
):
class
FuncType
(
Type
):
"""Function type.
"""Function type.
...
@@ -179,7 +192,7 @@ class FuncType(Type):
...
@@ -179,7 +192,7 @@ class FuncType(Type):
_ffi_api
.
FuncType
,
arg_types
,
ret_type
,
type_params
,
type_constraints
)
_ffi_api
.
FuncType
,
arg_types
,
ret_type
,
type_params
,
type_constraints
)
@tvm._ffi.register_object
(
"
relay.
IncompleteType"
)
@tvm._ffi.register_object
(
"IncompleteType"
)
class
IncompleteType
(
Type
):
class
IncompleteType
(
Type
):
"""Incomplete type during type inference.
"""Incomplete type during type inference.
...
...
python/tvm/ir/type_relation.py
View file @
6027412b
...
@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint
...
@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint
from
.
import
_ffi_api
from
.
import
_ffi_api
@tvm._ffi.register_object
(
"TypeCall"
)
class
TypeCall
(
Type
):
class
TypeCall
(
Type
):
"""Type function application.
"""Type function application.
...
@@ -41,7 +42,7 @@ class TypeCall(Type):
...
@@ -41,7 +42,7 @@ class TypeCall(Type):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
TypeCall
,
func
,
args
)
self
.
__init_handle_by_constructor__
(
_ffi_api
.
TypeCall
,
func
,
args
)
@tvm._ffi.register_object
(
"
relay.
TypeRelation"
)
@tvm._ffi.register_object
(
"TypeRelation"
)
class
TypeRelation
(
TypeConstraint
):
class
TypeRelation
(
TypeConstraint
):
"""User defined type relation, it is an input-output relation on types.
"""User defined type relation, it is an input-output relation on types.
...
...
src/ir/transform.cc
View file @
6027412b
...
@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
...
@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
*/
*/
PassInfo
Info
()
const
override
{
return
pass_info
;
}
PassInfo
Info
()
const
override
{
return
pass_info
;
}
static
constexpr
const
char
*
_type_key
=
"
relay
.ModulePass"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.ModulePass"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
ModulePassNode
,
PassNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
ModulePassNode
,
PassNode
);
};
};
...
@@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
...
@@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
*/
*/
IRModule
operator
()(
const
IRModule
&
mod
,
const
PassContext
&
pass_ctx
)
const
final
;
IRModule
operator
()(
const
IRModule
&
mod
,
const
PassContext
&
pass_ctx
)
const
final
;
static
constexpr
const
char
*
_type_key
=
"
relay
.Sequential"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.Sequential"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
SequentialNode
,
PassNode
);
TVM_DECLARE_FINAL_OBJECT_INFO
(
SequentialNode
,
PassNode
);
};
};
...
...
tests/python/relay/test_ir_nodes.py
View file @
6027412b
...
@@ -30,13 +30,6 @@ def check_json_roundtrip(node):
...
@@ -30,13 +30,6 @@ def check_json_roundtrip(node):
assert
graph_equal
(
back
,
node
)
assert
graph_equal
(
back
,
node
)
def
test_bad_constructor
():
try
:
x
=
relay
.
ty
.
TensorType
(
"xx"
,
"xx"
)
except
tvm
.
error
.
TVMError
:
pass
# Span
# Span
def
test_span
():
def
test_span
():
span
=
relay
.
Span
(
None
,
1
,
1
)
span
=
relay
.
Span
(
None
,
1
,
1
)
...
@@ -55,71 +48,6 @@ def test_span():
...
@@ -55,71 +48,6 @@ def test_span():
assert
back
.
lineno
==
span
.
lineno
assert
back
.
lineno
==
span
.
lineno
assert
back
.
col_offset
==
span
.
col_offset
assert
back
.
col_offset
==
span
.
col_offset
# Types
def
test_tensor_type
():
shape
=
tvm
.
runtime
.
convert
([
1
,
2
,
3
])
dtype
=
'float32'
tt
=
relay
.
TensorType
(
shape
,
dtype
)
assert
tt
.
dtype
==
dtype
assert
tt
.
shape
==
shape
assert
tt
.
span
==
None
str
(
tt
)
check_json_roundtrip
(
tt
)
def
test_type_param
():
tp
=
relay
.
TypeVar
(
'name'
,
relay
.
TypeKind
.
Type
)
assert
tp
.
kind
==
relay
.
TypeKind
.
Type
# assert tp.span # TODO allow us to set span
str
(
tp
)
check_json_roundtrip
(
tp
)
def
test_func_type
():
type_params
=
tvm
.
runtime
.
convert
([])
type_constraints
=
tvm
.
runtime
.
convert
([])
# TODO: fill me in
arg_types
=
tvm
.
runtime
.
convert
([])
ret_type
=
relay
.
TensorType
((
1
,
2
,
3
),
'float32'
)
tf
=
relay
.
FuncType
(
arg_types
,
ret_type
,
type_params
,
type_constraints
)
assert
tf
.
type_params
==
type_params
assert
tf
.
type_constraints
==
type_constraints
assert
tf
.
arg_types
==
arg_types
assert
tf
.
ret_type
==
ret_type
assert
tf
.
span
==
None
# TODO make sure we can set span
str
(
tf
)
check_json_roundtrip
(
tf
)
def
test_tuple_type
():
tp
=
relay
.
TypeVar
(
'tp'
,
relay
.
TypeKind
.
Type
)
tf
=
relay
.
FuncType
(
tvm
.
runtime
.
convert
([]),
None
,
tvm
.
runtime
.
convert
([]),
tvm
.
runtime
.
convert
([]))
tt
=
relay
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
fields
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
tup_ty
=
relay
.
TupleType
(
fields
)
assert
tup_ty
.
fields
==
fields
str
(
tup_ty
)
check_json_roundtrip
(
tup_ty
)
def
test_type_relation
():
tp
=
relay
.
TypeVar
(
'tp'
,
relay
.
TypeKind
.
Type
)
tf
=
relay
.
FuncType
(
tvm
.
runtime
.
convert
([]),
None
,
tvm
.
runtime
.
convert
([]),
tvm
.
runtime
.
convert
([]))
tt
=
relay
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
args
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
num_inputs
=
2
func
=
tvm
.
ir
.
EnvFunc
.
get
(
"tvm.relay.type_relation.Broadcast"
)
attrs
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"attr"
,
padding
=
(
3
,
4
))
tr
=
relay
.
TypeRelation
(
func
,
args
,
num_inputs
,
attrs
)
assert
tr
.
args
==
args
assert
tr
.
num_inputs
==
num_inputs
str
(
tr
)
check_json_roundtrip
(
tr
)
def
test_constant
():
def
test_constant
():
arr
=
tvm
.
nd
.
array
(
10
)
arr
=
tvm
.
nd
.
array
(
10
)
...
@@ -280,13 +208,7 @@ def test_conv2d_attrs():
...
@@ -280,13 +208,7 @@ def test_conv2d_attrs():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_bad_constructor
()
test_span
()
test_span
()
test_tensor_type
()
test_type_param
()
test_func_type
()
test_tuple_type
()
test_type_relation
()
test_constant
()
test_constant
()
test_tuple
()
test_tuple
()
test_local_var
()
test_local_var
()
...
...
tests/python/relay/test_json_compact.py
View file @
6027412b
...
@@ -17,7 +17,6 @@
...
@@ -17,7 +17,6 @@
import
tvm
import
tvm
from
tvm
import
te
from
tvm
import
te
from
tvm
import
relay
import
json
import
json
def
test_type_var
():
def
test_type_var
():
...
@@ -36,13 +35,81 @@ def test_type_var():
...
@@ -36,13 +35,81 @@ def test_type_var():
"b64ndarrays"
:
[],
"b64ndarrays"
:
[],
}
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
relay
.
TypeVar
)
assert
isinstance
(
tvar
,
tvm
.
ir
.
TypeVar
)
assert
tvar
.
name_hint
==
"in0"
assert
tvar
.
name_hint
==
"in0"
nodes
[
1
][
"type_key"
]
=
"relay.GlobalTypeVar"
nodes
[
1
][
"type_key"
]
=
"relay.GlobalTypeVar"
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
relay
.
GlobalTypeVar
)
assert
isinstance
(
tvar
,
tvm
.
ir
.
GlobalTypeVar
)
assert
tvar
.
name_hint
==
"in0"
assert
tvar
.
name_hint
==
"in0"
def
test_incomplete_type
():
nodes
=
[
{
"type_key"
:
""
},
{
"type_key"
:
"relay.IncompleteType"
,
"attrs"
:
{
"kind"
:
"0"
,
"span"
:
"0"
}}]
data
=
{
"root"
:
1
,
"nodes"
:
nodes
,
"attrs"
:
{
"tvm_version"
:
"0.6.0"
},
"b64ndarrays"
:
[],
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
tvm
.
ir
.
IncompleteType
)
def
test_func_tuple_type
():
nodes
=
[
{
"type_key"
:
""
},
{
"type_key"
:
"relay.FuncType"
,
"attrs"
:
{
"arg_types"
:
"2"
,
"ret_type"
:
"3"
,
"span"
:
"0"
,
"type_constraints"
:
"6"
,
"type_params"
:
"5"
}
},
{
"type_key"
:
"Array"
},
{
"type_key"
:
"relay.TupleType"
,
"attrs"
:
{
"fields"
:
"4"
,
"span"
:
"0"
}},
{
"type_key"
:
"Array"
},
{
"type_key"
:
"Array"
},
{
"type_key"
:
"Array"
}
]
data
=
{
"root"
:
1
,
"nodes"
:
nodes
,
"attrs"
:
{
"tvm_version"
:
"0.6.0"
},
"b64ndarrays"
:
[],
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
tvm
.
ir
.
FuncType
)
def
test_global_var
():
nodes
=
[
{
"type_key"
:
""
},
{
"type_key"
:
"relay.GlobalVar"
,
"attrs"
:
{
"_checked_type_"
:
"0"
,
"name_hint"
:
"x"
,
"span"
:
"0"
}
}
]
data
=
{
"root"
:
1
,
"nodes"
:
nodes
,
"attrs"
:
{
"tvm_version"
:
"0.6.0"
},
"b64ndarrays"
:
[],
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
tvm
.
ir
.
GlobalVar
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_type_var
()
test_type_var
()
test_incomplete_type
()
test_func_tuple_type
()
test_global_var
()
tests/python/unittest/test_ir_type.py
0 → 100644
View file @
6027412b
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test type nodes in the IR"""
import
tvm
def
check_json_roundtrip
(
node
):
from
tvm.relay.analysis
import
graph_equal
json_str
=
tvm
.
ir
.
save_json
(
node
)
back
=
tvm
.
ir
.
load_json
(
json_str
)
assert
graph_equal
(
back
,
node
)
def
test_prim_type
():
x
=
tvm
.
ir
.
PrimType
(
"int32"
)
assert
isinstance
(
x
,
tvm
.
ir
.
PrimType
)
assert
x
.
dtype
==
"int32"
def
test_tensor_type_bad_constructor
():
try
:
x
=
tvm
.
ir
.
TensorType
(
"xx"
,
"xx"
)
except
tvm
.
error
.
TVMError
:
pass
def
test_tensor_type
():
shape
=
tvm
.
runtime
.
convert
([
1
,
2
,
3
])
dtype
=
'float32'
tt
=
tvm
.
ir
.
TensorType
(
shape
,
dtype
)
assert
tt
.
dtype
==
dtype
assert
tt
.
shape
==
shape
assert
tt
.
span
==
None
str
(
tt
)
check_json_roundtrip
(
tt
)
def
test_type_param
():
tp
=
tvm
.
ir
.
TypeVar
(
'name'
,
tvm
.
ir
.
TypeKind
.
Type
)
assert
tp
.
kind
==
tvm
.
ir
.
TypeKind
.
Type
# assert tp.span # TODO allow us to set span
str
(
tp
)
check_json_roundtrip
(
tp
)
def
test_func_type
():
type_params
=
tvm
.
runtime
.
convert
([])
type_constraints
=
tvm
.
runtime
.
convert
([])
# TODO: fill me in
arg_types
=
tvm
.
runtime
.
convert
([])
ret_type
=
tvm
.
ir
.
TensorType
((
1
,
2
,
3
),
'float32'
)
tf
=
tvm
.
ir
.
FuncType
(
arg_types
,
ret_type
,
type_params
,
type_constraints
)
assert
tf
.
type_params
==
type_params
assert
tf
.
type_constraints
==
type_constraints
assert
tf
.
arg_types
==
arg_types
assert
tf
.
ret_type
==
ret_type
assert
tf
.
span
==
None
# TODO make sure we can set span
str
(
tf
)
check_json_roundtrip
(
tf
)
def
test_tuple_type
():
tp
=
tvm
.
ir
.
TypeVar
(
'tp'
,
tvm
.
ir
.
TypeKind
.
Type
)
tf
=
tvm
.
ir
.
FuncType
([],
None
,
[],
[])
tt
=
tvm
.
ir
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
fields
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
tup_ty
=
tvm
.
ir
.
TupleType
(
fields
)
assert
tup_ty
.
fields
==
fields
str
(
tup_ty
)
check_json_roundtrip
(
tup_ty
)
def
test_type_relation
():
tp
=
tvm
.
ir
.
TypeVar
(
'tp'
,
tvm
.
ir
.
TypeKind
.
Type
)
tf
=
tvm
.
ir
.
FuncType
([],
None
,
[],
[])
tt
=
tvm
.
ir
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
args
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
num_inputs
=
2
func
=
tvm
.
ir
.
EnvFunc
.
get
(
"tvm.relay.type_relation.Broadcast"
)
attrs
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"attr"
,
padding
=
(
3
,
4
))
tr
=
tvm
.
ir
.
TypeRelation
(
func
,
args
,
num_inputs
,
attrs
)
assert
tr
.
args
==
args
assert
tr
.
num_inputs
==
num_inputs
str
(
tr
)
check_json_roundtrip
(
tr
)
if
__name__
==
"__main__"
:
test_tensor_type_bad_constructor
()
test_tensor_type
()
test_type_param
()
test_func_type
()
test_tuple_type
()
test_type_relation
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment