Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6027412b
Unverified
Commit
6027412b
authored
Mar 15, 2020
by
Tianqi Chen
Committed by
GitHub
Mar 15, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Update the type_keys to reflect the code-org (#5074)
parent
7c5ff508
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
255 additions
and
118 deletions
+255
-118
include/tvm/ir/expr.h
+1
-1
include/tvm/ir/module.h
+1
-1
include/tvm/ir/span.h
+2
-2
include/tvm/ir/transform.h
+3
-3
include/tvm/ir/type.h
+10
-8
include/tvm/ir/type_relation.h
+3
-3
python/tvm/ir/__init__.py
+1
-1
python/tvm/ir/base.py
+2
-2
python/tvm/ir/expr.py
+1
-1
python/tvm/ir/json_compact.py
+24
-0
python/tvm/ir/module.py
+1
-1
python/tvm/ir/transform.py
+5
-5
python/tvm/ir/type.py
+19
-6
python/tvm/ir/type_relation.py
+2
-1
src/ir/transform.cc
+2
-2
tests/python/relay/test_ir_nodes.py
+0
-78
tests/python/relay/test_json_compact.py
+70
-3
tests/python/unittest/test_ir_type.py
+108
-0
No files found.
include/tvm/ir/expr.h
View file @
6027412b
...
...
@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
v
->
Visit
(
"_checked_type_"
,
&
checked_type_
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
GlobalVar"
;
static
constexpr
const
char
*
_type_key
=
"GlobalVar"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
GlobalVarNode
,
RelayExprNode
);
};
...
...
include/tvm/ir/module.h
View file @
6027412b
...
...
@@ -226,7 +226,7 @@ class IRModuleNode : public Object {
*/
TVM_DLL
std
::
unordered_set
<
std
::
string
>
Imports
()
const
;
static
constexpr
const
char
*
_type_key
=
"
relay.
Module"
;
static
constexpr
const
char
*
_type_key
=
"
IR
Module"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
IRModuleNode
,
Object
);
private
:
...
...
include/tvm/ir/span.h
View file @
6027412b
...
...
@@ -44,7 +44,7 @@ class SourceNameNode : public Object {
// override attr visitor
void
VisitAttrs
(
AttrVisitor
*
v
)
{
v
->
Visit
(
"name"
,
&
name
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
SourceName"
;
static
constexpr
const
char
*
_type_key
=
"SourceName"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
SourceNameNode
,
Object
);
};
...
...
@@ -89,7 +89,7 @@ class SpanNode : public Object {
TVM_DLL
static
Span
make
(
SourceName
source
,
int
lineno
,
int
col_offset
);
static
constexpr
const
char
*
_type_key
=
"
relay.
Span"
;
static
constexpr
const
char
*
_type_key
=
"Span"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
SpanNode
,
Object
);
};
...
...
include/tvm/ir/transform.h
View file @
6027412b
...
...
@@ -110,7 +110,7 @@ class PassContextNode : public Object {
v
->
Visit
(
"disabled_pass"
,
&
disabled_pass
);
}
static
constexpr
const
char
*
_type_key
=
"
relay
.PassContext"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.PassContext"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
PassContextNode
,
Object
);
};
...
...
@@ -206,7 +206,7 @@ class PassInfoNode : public Object {
v
->
Visit
(
"required"
,
&
required
);
}
static
constexpr
const
char
*
_type_key
=
"
relay
.PassInfo"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.PassInfo"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
PassInfoNode
,
Object
);
};
...
...
@@ -265,7 +265,7 @@ class PassNode : public Object {
void
VisitAttrs
(
AttrVisitor
*
v
)
{}
static
constexpr
const
char
*
_type_key
=
"
relay
.Pass"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.Pass"
;
TVM_DECLARE_BASE_OBJECT_INFO
(
PassNode
,
Object
);
};
...
...
include/tvm/ir/type.h
View file @
6027412b
...
...
@@ -78,7 +78,7 @@ class TypeNode : public Object {
*/
mutable
Span
span
;
static
constexpr
const
char
*
_type_key
=
"
relay.
Type"
;
static
constexpr
const
char
*
_type_key
=
"Type"
;
TVM_DECLARE_BASE_OBJECT_INFO
(
TypeNode
,
Object
);
};
...
...
@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
v
->
Visit
(
"dtype"
,
&
dtype
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
PrimType"
;
static
constexpr
const
char
*
_type_key
=
"PrimType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
PrimTypeNode
,
TypeNode
);
};
...
...
@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeVar"
;
static
constexpr
const
char
*
_type_key
=
"TypeVar"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeVarNode
,
TypeNode
);
};
...
...
@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
v
->
Visit
(
"kind"
,
&
kind
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
GlobalTypeVar"
;
static
constexpr
const
char
*
_type_key
=
"GlobalTypeVar"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
GlobalTypeVarNode
,
TypeNode
);
};
...
...
@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TupleType"
;
static
constexpr
const
char
*
_type_key
=
"TupleType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TupleTypeNode
,
TypeNode
);
};
...
...
@@ -289,7 +289,7 @@ inline Type VoidType() {
*/
class
TypeConstraintNode
:
public
TypeNode
{
public
:
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeConstraint"
;
static
constexpr
const
char
*
_type_key
=
"TypeConstraint"
;
TVM_DECLARE_BASE_OBJECT_INFO
(
TypeConstraintNode
,
TypeNode
);
};
...
...
@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
FuncType"
;
static
constexpr
const
char
*
_type_key
=
"FuncType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
FuncTypeNode
,
TypeNode
);
};
...
...
@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
IncompleteType"
;
static
constexpr
const
char
*
_type_key
=
"IncompleteType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
IncompleteTypeNode
,
TypeNode
);
};
...
...
@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
}
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static
constexpr
const
char
*
_type_key
=
"relay.RefType"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
RelayRefTypeNode
,
TypeNode
);
};
...
...
include/tvm/ir/type_relation.h
View file @
6027412b
...
...
@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
v
->
Visit
(
"span"
,
&
span
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeCall"
;
static
constexpr
const
char
*
_type_key
=
"TypeCall"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeCallNode
,
TypeNode
);
};
...
...
@@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
// solver is not serializable.
void
VisitAttrs
(
AttrVisitor
*
v
)
{}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeReporter"
;
static
constexpr
const
char
*
_type_key
=
"TypeReporter"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeReporterNode
,
Object
);
};
...
...
@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
v
->
Visit
(
"span"
,
&
span
);
}
static
constexpr
const
char
*
_type_key
=
"
relay.
TypeRelation"
;
static
constexpr
const
char
*
_type_key
=
"TypeRelation"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
TypeRelationNode
,
TypeConstraintNode
);
};
...
...
python/tvm/ir/__init__.py
View file @
6027412b
...
...
@@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from
.base
import
SourceName
,
Span
,
Node
,
EnvFunc
,
load_json
,
save_json
from
.type
import
Type
,
TypeKind
,
TypeVar
,
GlobalTypeVar
,
TupleType
from
.type
import
Type
,
TypeKind
,
PrimType
,
TypeVar
,
GlobalTypeVar
,
TupleType
from
.type
import
TypeConstraint
,
FuncType
,
IncompleteType
,
RelayRefType
from
.tensor_type
import
TensorType
from
.type_relation
import
TypeCall
,
TypeRelation
...
...
python/tvm/ir/base.py
View file @
6027412b
...
...
@@ -56,7 +56,7 @@ class Node(Object):
return
_ffi_api
.
PrettyPrint
(
self
)
@tvm._ffi.register_object
(
"
relay.
SourceName"
)
@tvm._ffi.register_object
(
"SourceName"
)
class
SourceName
(
Object
):
"""A identifier for a source location.
...
...
@@ -69,7 +69,7 @@ class SourceName(Object):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
SourceName
,
name
)
@tvm._ffi.register_object
(
"
relay.
Span"
)
@tvm._ffi.register_object
(
"Span"
)
class
Span
(
Object
):
"""Specifies a location in a source program.
...
...
python/tvm/ir/expr.py
View file @
6027412b
...
...
@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr):
return
ret
@tvm._ffi.register_object
(
"
relay.
GlobalVar"
)
@tvm._ffi.register_object
(
"GlobalVar"
)
class
GlobalVar
(
RelayExpr
):
"""A global variable in the IR.
...
...
python/tvm/ir/json_compact.py
View file @
6027412b
...
...
@@ -62,11 +62,35 @@ def create_updater_06_to_07():
# set vindex to null
nodes
[
vindex
][
"type_key"
]
=
""
del
item
[
"attrs"
][
"var"
]
assert
item
[
"type_key"
]
.
startswith
(
"relay."
)
item
[
"type_key"
]
=
item
[
"type_key"
][
len
(
"relay."
):]
return
item
def
_rename
(
new_name
):
def
_convert
(
item
,
_
):
item
[
"type_key"
]
=
new_name
return
item
return
_convert
node_map
=
{
"relay.TypeVar"
:
_ftype_var
,
"relay.GlobalTypeVar"
:
_ftype_var
,
"relay.Type"
:
_rename
(
"Type"
),
"relay.TupleType"
:
_rename
(
"TupleType"
),
"relay.TypeConstraint"
:
_rename
(
"TypeConstraint"
),
"relay.FuncType"
:
_rename
(
"FuncType"
),
"relay.IncompleteType"
:
_rename
(
"IncompleteType"
),
"relay.TypeRelation"
:
_rename
(
"TypeRelation"
),
"relay.TypeCall"
:
_rename
(
"TypeCall"
),
"relay.Module"
:
_rename
(
"IRModule"
),
"relay.SourceName"
:
_rename
(
"SourceName"
),
"relay.Span"
:
_rename
(
"Span"
),
"relay.GlobalVar"
:
_rename
(
"GlobalVar"
),
"relay.Pass"
:
_rename
(
"transform.Pass"
),
"relay.PassInfo"
:
_rename
(
"transform.PassInfo"
),
"relay.PassContext"
:
_rename
(
"transform.PassContext"
),
"relay.ModulePass"
:
_rename
(
"transform.ModulePass"
),
"relay.Sequantial"
:
_rename
(
"transform.Sequantial"
),
}
return
create_updater
(
node_map
,
"0.6"
,
"0.7"
)
...
...
python/tvm/ir/module.py
View file @
6027412b
...
...
@@ -24,7 +24,7 @@ from . import type as _ty
from
.
import
_ffi_api
@tvm._ffi.register_object
(
"
relay.
Module"
)
@tvm._ffi.register_object
(
"
IR
Module"
)
class
IRModule
(
Node
):
"""IRModule that holds functions and type definitions.
...
...
python/tvm/ir/transform.py
View file @
6027412b
...
...
@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd
from
.
import
_ffi_transform_api
@tvm._ffi.register_object
(
"
relay
.PassInfo"
)
@tvm._ffi.register_object
(
"
transform
.PassInfo"
)
class
PassInfo
(
Object
):
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
...
...
@@ -51,7 +51,7 @@ class PassInfo(Object):
_ffi_transform_api
.
PassInfo
,
opt_level
,
name
,
required
)
@tvm._ffi.register_object
(
"
relay
.PassContext"
)
@tvm._ffi.register_object
(
"
transform
.PassContext"
)
class
PassContext
(
Object
):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
...
...
@@ -112,7 +112,7 @@ class PassContext(Object):
return
_ffi_transform_api
.
GetCurrentPassContext
()
@tvm._ffi.register_object
(
"
relay
.Pass"
)
@tvm._ffi.register_object
(
"
transform
.Pass"
)
class
Pass
(
Object
):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
...
...
@@ -141,7 +141,7 @@ class Pass(Object):
return
_ffi_transform_api
.
RunPass
(
self
,
mod
)
@tvm._ffi.register_object
(
"
relay
.ModulePass"
)
@tvm._ffi.register_object
(
"
transform
.ModulePass"
)
class
ModulePass
(
Pass
):
"""A pass that works on tvm.IRModule. Users don't need to interact with
this class directly. Instead, a module pass should be created through
...
...
@@ -152,7 +152,7 @@ class ModulePass(Pass):
"""
@tvm._ffi.register_object
(
"
relay
.Sequential"
)
@tvm._ffi.register_object
(
"
transform
.Sequential"
)
class
Sequential
(
Pass
):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
...
...
python/tvm/ir/type.py
View file @
6027412b
...
...
@@ -46,7 +46,20 @@ class TypeKind(IntEnum):
TypeData
=
6
@tvm._ffi.register_object
(
"relay.TypeVar"
)
class
PrimType
(
Type
):
"""Primitive data type in the low level IR
Parameters
----------
dtype : str
The runtime data type relates to the primtype.
"""
def
__init__
(
self
,
dtype
):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
PrimType
,
dtype
)
@tvm._ffi.register_object
(
"TypeVar"
)
class
TypeVar
(
Type
):
"""Type parameter in functions.
...
...
@@ -85,7 +98,7 @@ class TypeVar(Type):
return
TypeCall
(
self
,
args
)
@tvm._ffi.register_object
(
"
relay.
GlobalTypeVar"
)
@tvm._ffi.register_object
(
"GlobalTypeVar"
)
class
GlobalTypeVar
(
Type
):
"""A global type variable that is used for defining new types or type aliases.
...
...
@@ -120,7 +133,7 @@ class GlobalTypeVar(Type):
return
TypeCall
(
self
,
args
)
@tvm._ffi.register_object
(
"
relay.
TupleType"
)
@tvm._ffi.register_object
(
"TupleType"
)
class
TupleType
(
Type
):
"""The type of tuple values.
...
...
@@ -135,12 +148,12 @@ class TupleType(Type):
_ffi_api
.
TupleType
,
fields
)
@tvm._ffi.register_object
(
"
relay.
TypeConstraint"
)
@tvm._ffi.register_object
(
"TypeConstraint"
)
class
TypeConstraint
(
Type
):
"""Abstract class representing a type constraint."""
@tvm._ffi.register_object
(
"
relay.
FuncType"
)
@tvm._ffi.register_object
(
"FuncType"
)
class
FuncType
(
Type
):
"""Function type.
...
...
@@ -179,7 +192,7 @@ class FuncType(Type):
_ffi_api
.
FuncType
,
arg_types
,
ret_type
,
type_params
,
type_constraints
)
@tvm._ffi.register_object
(
"
relay.
IncompleteType"
)
@tvm._ffi.register_object
(
"IncompleteType"
)
class
IncompleteType
(
Type
):
"""Incomplete type during type inference.
...
...
python/tvm/ir/type_relation.py
View file @
6027412b
...
...
@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint
from
.
import
_ffi_api
@tvm._ffi.register_object
(
"TypeCall"
)
class
TypeCall
(
Type
):
"""Type function application.
...
...
@@ -41,7 +42,7 @@ class TypeCall(Type):
self
.
__init_handle_by_constructor__
(
_ffi_api
.
TypeCall
,
func
,
args
)
@tvm._ffi.register_object
(
"
relay.
TypeRelation"
)
@tvm._ffi.register_object
(
"TypeRelation"
)
class
TypeRelation
(
TypeConstraint
):
"""User defined type relation, it is an input-output relation on types.
...
...
src/ir/transform.cc
View file @
6027412b
...
...
@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
*/
PassInfo
Info
()
const
override
{
return
pass_info
;
}
static
constexpr
const
char
*
_type_key
=
"
relay
.ModulePass"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.ModulePass"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
ModulePassNode
,
PassNode
);
};
...
...
@@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
*/
IRModule
operator
()(
const
IRModule
&
mod
,
const
PassContext
&
pass_ctx
)
const
final
;
static
constexpr
const
char
*
_type_key
=
"
relay
.Sequential"
;
static
constexpr
const
char
*
_type_key
=
"
transform
.Sequential"
;
TVM_DECLARE_FINAL_OBJECT_INFO
(
SequentialNode
,
PassNode
);
};
...
...
tests/python/relay/test_ir_nodes.py
View file @
6027412b
...
...
@@ -30,13 +30,6 @@ def check_json_roundtrip(node):
assert
graph_equal
(
back
,
node
)
def
test_bad_constructor
():
try
:
x
=
relay
.
ty
.
TensorType
(
"xx"
,
"xx"
)
except
tvm
.
error
.
TVMError
:
pass
# Span
def
test_span
():
span
=
relay
.
Span
(
None
,
1
,
1
)
...
...
@@ -55,71 +48,6 @@ def test_span():
assert
back
.
lineno
==
span
.
lineno
assert
back
.
col_offset
==
span
.
col_offset
# Types
def
test_tensor_type
():
shape
=
tvm
.
runtime
.
convert
([
1
,
2
,
3
])
dtype
=
'float32'
tt
=
relay
.
TensorType
(
shape
,
dtype
)
assert
tt
.
dtype
==
dtype
assert
tt
.
shape
==
shape
assert
tt
.
span
==
None
str
(
tt
)
check_json_roundtrip
(
tt
)
def
test_type_param
():
tp
=
relay
.
TypeVar
(
'name'
,
relay
.
TypeKind
.
Type
)
assert
tp
.
kind
==
relay
.
TypeKind
.
Type
# assert tp.span # TODO allow us to set span
str
(
tp
)
check_json_roundtrip
(
tp
)
def
test_func_type
():
type_params
=
tvm
.
runtime
.
convert
([])
type_constraints
=
tvm
.
runtime
.
convert
([])
# TODO: fill me in
arg_types
=
tvm
.
runtime
.
convert
([])
ret_type
=
relay
.
TensorType
((
1
,
2
,
3
),
'float32'
)
tf
=
relay
.
FuncType
(
arg_types
,
ret_type
,
type_params
,
type_constraints
)
assert
tf
.
type_params
==
type_params
assert
tf
.
type_constraints
==
type_constraints
assert
tf
.
arg_types
==
arg_types
assert
tf
.
ret_type
==
ret_type
assert
tf
.
span
==
None
# TODO make sure we can set span
str
(
tf
)
check_json_roundtrip
(
tf
)
def
test_tuple_type
():
tp
=
relay
.
TypeVar
(
'tp'
,
relay
.
TypeKind
.
Type
)
tf
=
relay
.
FuncType
(
tvm
.
runtime
.
convert
([]),
None
,
tvm
.
runtime
.
convert
([]),
tvm
.
runtime
.
convert
([]))
tt
=
relay
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
fields
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
tup_ty
=
relay
.
TupleType
(
fields
)
assert
tup_ty
.
fields
==
fields
str
(
tup_ty
)
check_json_roundtrip
(
tup_ty
)
def
test_type_relation
():
tp
=
relay
.
TypeVar
(
'tp'
,
relay
.
TypeKind
.
Type
)
tf
=
relay
.
FuncType
(
tvm
.
runtime
.
convert
([]),
None
,
tvm
.
runtime
.
convert
([]),
tvm
.
runtime
.
convert
([]))
tt
=
relay
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
args
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
num_inputs
=
2
func
=
tvm
.
ir
.
EnvFunc
.
get
(
"tvm.relay.type_relation.Broadcast"
)
attrs
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"attr"
,
padding
=
(
3
,
4
))
tr
=
relay
.
TypeRelation
(
func
,
args
,
num_inputs
,
attrs
)
assert
tr
.
args
==
args
assert
tr
.
num_inputs
==
num_inputs
str
(
tr
)
check_json_roundtrip
(
tr
)
def
test_constant
():
arr
=
tvm
.
nd
.
array
(
10
)
...
...
@@ -280,13 +208,7 @@ def test_conv2d_attrs():
if
__name__
==
"__main__"
:
test_bad_constructor
()
test_span
()
test_tensor_type
()
test_type_param
()
test_func_type
()
test_tuple_type
()
test_type_relation
()
test_constant
()
test_tuple
()
test_local_var
()
...
...
tests/python/relay/test_json_compact.py
View file @
6027412b
...
...
@@ -17,7 +17,6 @@
import
tvm
from
tvm
import
te
from
tvm
import
relay
import
json
def
test_type_var
():
...
...
@@ -36,13 +35,81 @@ def test_type_var():
"b64ndarrays"
:
[],
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
relay
.
TypeVar
)
assert
isinstance
(
tvar
,
tvm
.
ir
.
TypeVar
)
assert
tvar
.
name_hint
==
"in0"
nodes
[
1
][
"type_key"
]
=
"relay.GlobalTypeVar"
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
relay
.
GlobalTypeVar
)
assert
isinstance
(
tvar
,
tvm
.
ir
.
GlobalTypeVar
)
assert
tvar
.
name_hint
==
"in0"
def
test_incomplete_type
():
nodes
=
[
{
"type_key"
:
""
},
{
"type_key"
:
"relay.IncompleteType"
,
"attrs"
:
{
"kind"
:
"0"
,
"span"
:
"0"
}}]
data
=
{
"root"
:
1
,
"nodes"
:
nodes
,
"attrs"
:
{
"tvm_version"
:
"0.6.0"
},
"b64ndarrays"
:
[],
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
tvm
.
ir
.
IncompleteType
)
def
test_func_tuple_type
():
nodes
=
[
{
"type_key"
:
""
},
{
"type_key"
:
"relay.FuncType"
,
"attrs"
:
{
"arg_types"
:
"2"
,
"ret_type"
:
"3"
,
"span"
:
"0"
,
"type_constraints"
:
"6"
,
"type_params"
:
"5"
}
},
{
"type_key"
:
"Array"
},
{
"type_key"
:
"relay.TupleType"
,
"attrs"
:
{
"fields"
:
"4"
,
"span"
:
"0"
}},
{
"type_key"
:
"Array"
},
{
"type_key"
:
"Array"
},
{
"type_key"
:
"Array"
}
]
data
=
{
"root"
:
1
,
"nodes"
:
nodes
,
"attrs"
:
{
"tvm_version"
:
"0.6.0"
},
"b64ndarrays"
:
[],
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
tvm
.
ir
.
FuncType
)
def
test_global_var
():
nodes
=
[
{
"type_key"
:
""
},
{
"type_key"
:
"relay.GlobalVar"
,
"attrs"
:
{
"_checked_type_"
:
"0"
,
"name_hint"
:
"x"
,
"span"
:
"0"
}
}
]
data
=
{
"root"
:
1
,
"nodes"
:
nodes
,
"attrs"
:
{
"tvm_version"
:
"0.6.0"
},
"b64ndarrays"
:
[],
}
tvar
=
tvm
.
ir
.
load_json
(
json
.
dumps
(
data
))
assert
isinstance
(
tvar
,
tvm
.
ir
.
GlobalVar
)
if
__name__
==
"__main__"
:
test_type_var
()
test_incomplete_type
()
test_func_tuple_type
()
test_global_var
()
tests/python/unittest/test_ir_type.py
0 → 100644
View file @
6027412b
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test type nodes in the IR"""
import
tvm
def
check_json_roundtrip
(
node
):
from
tvm.relay.analysis
import
graph_equal
json_str
=
tvm
.
ir
.
save_json
(
node
)
back
=
tvm
.
ir
.
load_json
(
json_str
)
assert
graph_equal
(
back
,
node
)
def
test_prim_type
():
x
=
tvm
.
ir
.
PrimType
(
"int32"
)
assert
isinstance
(
x
,
tvm
.
ir
.
PrimType
)
assert
x
.
dtype
==
"int32"
def
test_tensor_type_bad_constructor
():
try
:
x
=
tvm
.
ir
.
TensorType
(
"xx"
,
"xx"
)
except
tvm
.
error
.
TVMError
:
pass
def
test_tensor_type
():
shape
=
tvm
.
runtime
.
convert
([
1
,
2
,
3
])
dtype
=
'float32'
tt
=
tvm
.
ir
.
TensorType
(
shape
,
dtype
)
assert
tt
.
dtype
==
dtype
assert
tt
.
shape
==
shape
assert
tt
.
span
==
None
str
(
tt
)
check_json_roundtrip
(
tt
)
def
test_type_param
():
tp
=
tvm
.
ir
.
TypeVar
(
'name'
,
tvm
.
ir
.
TypeKind
.
Type
)
assert
tp
.
kind
==
tvm
.
ir
.
TypeKind
.
Type
# assert tp.span # TODO allow us to set span
str
(
tp
)
check_json_roundtrip
(
tp
)
def
test_func_type
():
type_params
=
tvm
.
runtime
.
convert
([])
type_constraints
=
tvm
.
runtime
.
convert
([])
# TODO: fill me in
arg_types
=
tvm
.
runtime
.
convert
([])
ret_type
=
tvm
.
ir
.
TensorType
((
1
,
2
,
3
),
'float32'
)
tf
=
tvm
.
ir
.
FuncType
(
arg_types
,
ret_type
,
type_params
,
type_constraints
)
assert
tf
.
type_params
==
type_params
assert
tf
.
type_constraints
==
type_constraints
assert
tf
.
arg_types
==
arg_types
assert
tf
.
ret_type
==
ret_type
assert
tf
.
span
==
None
# TODO make sure we can set span
str
(
tf
)
check_json_roundtrip
(
tf
)
def
test_tuple_type
():
tp
=
tvm
.
ir
.
TypeVar
(
'tp'
,
tvm
.
ir
.
TypeKind
.
Type
)
tf
=
tvm
.
ir
.
FuncType
([],
None
,
[],
[])
tt
=
tvm
.
ir
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
fields
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
tup_ty
=
tvm
.
ir
.
TupleType
(
fields
)
assert
tup_ty
.
fields
==
fields
str
(
tup_ty
)
check_json_roundtrip
(
tup_ty
)
def
test_type_relation
():
tp
=
tvm
.
ir
.
TypeVar
(
'tp'
,
tvm
.
ir
.
TypeKind
.
Type
)
tf
=
tvm
.
ir
.
FuncType
([],
None
,
[],
[])
tt
=
tvm
.
ir
.
TensorType
(
tvm
.
runtime
.
convert
([
1
,
2
,
3
]),
'float32'
)
args
=
tvm
.
runtime
.
convert
([
tp
,
tf
,
tt
])
num_inputs
=
2
func
=
tvm
.
ir
.
EnvFunc
.
get
(
"tvm.relay.type_relation.Broadcast"
)
attrs
=
tvm
.
ir
.
make_node
(
"attrs.TestAttrs"
,
name
=
"attr"
,
padding
=
(
3
,
4
))
tr
=
tvm
.
ir
.
TypeRelation
(
func
,
args
,
num_inputs
,
attrs
)
assert
tr
.
args
==
args
assert
tr
.
num_inputs
==
num_inputs
str
(
tr
)
check_json_roundtrip
(
tr
)
if
__name__
==
"__main__"
:
test_tensor_type_bad_constructor
()
test_tensor_type
()
test_type_param
()
test_func_type
()
test_tuple_type
()
test_type_relation
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment