Unverified Commit 6027412b by Tianqi Chen Committed by GitHub

[IR] Update the type_keys to reflect the code-org (#5074)

parent 7c5ff508
...@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode { ...@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_); v->Visit("_checked_type_", &checked_type_);
} }
static constexpr const char* _type_key = "relay.GlobalVar"; static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode); TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
}; };
......
...@@ -226,7 +226,7 @@ class IRModuleNode : public Object { ...@@ -226,7 +226,7 @@ class IRModuleNode : public Object {
*/ */
TVM_DLL std::unordered_set<std::string> Imports() const; TVM_DLL std::unordered_set<std::string> Imports() const;
static constexpr const char* _type_key = "relay.Module"; static constexpr const char* _type_key = "IRModule";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private: private:
......
...@@ -44,7 +44,7 @@ class SourceNameNode : public Object { ...@@ -44,7 +44,7 @@ class SourceNameNode : public Object {
// override attr visitor // override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); } void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName"; static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
}; };
...@@ -89,7 +89,7 @@ class SpanNode : public Object { ...@@ -89,7 +89,7 @@ class SpanNode : public Object {
TVM_DLL static Span make(SourceName source, int lineno, int col_offset); TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "relay.Span"; static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
}; };
......
...@@ -110,7 +110,7 @@ class PassContextNode : public Object { ...@@ -110,7 +110,7 @@ class PassContextNode : public Object {
v->Visit("disabled_pass", &disabled_pass); v->Visit("disabled_pass", &disabled_pass);
} }
static constexpr const char* _type_key = "relay.PassContext"; static constexpr const char* _type_key = "transform.PassContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
}; };
...@@ -206,7 +206,7 @@ class PassInfoNode : public Object { ...@@ -206,7 +206,7 @@ class PassInfoNode : public Object {
v->Visit("required", &required); v->Visit("required", &required);
} }
static constexpr const char* _type_key = "relay.PassInfo"; static constexpr const char* _type_key = "transform.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
}; };
...@@ -265,7 +265,7 @@ class PassNode : public Object { ...@@ -265,7 +265,7 @@ class PassNode : public Object {
void VisitAttrs(AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass"; static constexpr const char* _type_key = "transform.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
}; };
......
...@@ -78,7 +78,7 @@ class TypeNode : public Object { ...@@ -78,7 +78,7 @@ class TypeNode : public Object {
*/ */
mutable Span span; mutable Span span;
static constexpr const char* _type_key = "relay.Type"; static constexpr const char* _type_key = "Type";
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
}; };
...@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode { ...@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
} }
static constexpr const char* _type_key = "relay.PrimType"; static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
}; };
...@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode { ...@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
static constexpr const char* _type_key = "relay.TypeVar"; static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
}; };
...@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode { ...@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind); v->Visit("kind", &kind);
} }
static constexpr const char* _type_key = "relay.GlobalTypeVar"; static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
}; };
...@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode { ...@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
static constexpr const char* _type_key = "relay.TupleType"; static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
}; };
...@@ -289,7 +289,7 @@ inline Type VoidType() { ...@@ -289,7 +289,7 @@ inline Type VoidType() {
*/ */
class TypeConstraintNode : public TypeNode { class TypeConstraintNode : public TypeNode {
public: public:
static constexpr const char* _type_key = "relay.TypeConstraint"; static constexpr const char* _type_key = "TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode); TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
}; };
...@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode { ...@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
static constexpr const char* _type_key = "relay.FuncType"; static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
}; };
...@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode { ...@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
static constexpr const char* _type_key = "relay.IncompleteType"; static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
}; };
...@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode { ...@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType"; static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
}; };
......
...@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode { ...@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
static constexpr const char* _type_key = "relay.TypeCall"; static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
}; };
...@@ -119,7 +119,7 @@ class TypeReporterNode : public Object { ...@@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
// solver is not serializable. // solver is not serializable.
void VisitAttrs(AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter"; static constexpr const char* _type_key = "TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
}; };
...@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode { ...@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span); v->Visit("span", &span);
} }
static constexpr const char* _type_key = "relay.TypeRelation"; static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode); TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
}; };
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# pylint: disable=unused-import # pylint: disable=unused-import
"""Common data structures across all IR variants.""" """Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation from .type_relation import TypeCall, TypeRelation
......
...@@ -56,7 +56,7 @@ class Node(Object): ...@@ -56,7 +56,7 @@ class Node(Object):
return _ffi_api.PrettyPrint(self) return _ffi_api.PrettyPrint(self)
@tvm._ffi.register_object("relay.SourceName") @tvm._ffi.register_object("SourceName")
class SourceName(Object): class SourceName(Object):
"""A identifier for a source location. """A identifier for a source location.
...@@ -69,7 +69,7 @@ class SourceName(Object): ...@@ -69,7 +69,7 @@ class SourceName(Object):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name) self.__init_handle_by_constructor__(_ffi_api.SourceName, name)
@tvm._ffi.register_object("relay.Span") @tvm._ffi.register_object("Span")
class Span(Object): class Span(Object):
"""Specifies a location in a source program. """Specifies a location in a source program.
......
...@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr): ...@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr):
return ret return ret
@tvm._ffi.register_object("relay.GlobalVar") @tvm._ffi.register_object("GlobalVar")
class GlobalVar(RelayExpr): class GlobalVar(RelayExpr):
"""A global variable in the IR. """A global variable in the IR.
......
...@@ -62,11 +62,35 @@ def create_updater_06_to_07(): ...@@ -62,11 +62,35 @@ def create_updater_06_to_07():
# set vindex to null # set vindex to null
nodes[vindex]["type_key"] = "" nodes[vindex]["type_key"] = ""
del item["attrs"]["var"] del item["attrs"]["var"]
assert item["type_key"].startswith("relay.")
item["type_key"] = item["type_key"][len("relay."):]
return item return item
def _rename(new_name):
def _convert(item, _):
item["type_key"] = new_name
return item
return _convert
node_map = { node_map = {
"relay.TypeVar": _ftype_var, "relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
"relay.TypeConstraint": _rename("TypeConstraint"),
"relay.FuncType": _rename("FuncType"),
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
"relay.GlobalVar": _rename("GlobalVar"),
"relay.Pass": _rename("transform.Pass"),
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
} }
return create_updater(node_map, "0.6", "0.7") return create_updater(node_map, "0.6", "0.7")
......
...@@ -24,7 +24,7 @@ from . import type as _ty ...@@ -24,7 +24,7 @@ from . import type as _ty
from . import _ffi_api from . import _ffi_api
@tvm._ffi.register_object("relay.Module") @tvm._ffi.register_object("IRModule")
class IRModule(Node): class IRModule(Node):
"""IRModule that holds functions and type definitions. """IRModule that holds functions and type definitions.
......
...@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd ...@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd
from . import _ffi_transform_api from . import _ffi_transform_api
@tvm._ffi.register_object("relay.PassInfo") @tvm._ffi.register_object("transform.PassInfo")
class PassInfo(Object): class PassInfo(Object):
"""The class contains the meta data required by a pass. It is the """The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis. container of information needed by running an optimization or analysis.
...@@ -51,7 +51,7 @@ class PassInfo(Object): ...@@ -51,7 +51,7 @@ class PassInfo(Object):
_ffi_transform_api.PassInfo, opt_level, name, required) _ffi_transform_api.PassInfo, opt_level, name, required)
@tvm._ffi.register_object("relay.PassContext") @tvm._ffi.register_object("transform.PassContext")
class PassContext(Object): class PassContext(Object):
"""The basis where a Relay optimization/analysis runs on. """The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used Each pass context contains a number of auxiliary information that is used
...@@ -112,7 +112,7 @@ class PassContext(Object): ...@@ -112,7 +112,7 @@ class PassContext(Object):
return _ffi_transform_api.GetCurrentPassContext() return _ffi_transform_api.GetCurrentPassContext()
@tvm._ffi.register_object("relay.Pass") @tvm._ffi.register_object("transform.Pass")
class Pass(Object): class Pass(Object):
"""The base class of all passes. All methods here are just simple wrappers """The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to that are implemented in the backend. They are defined for users to
...@@ -141,7 +141,7 @@ class Pass(Object): ...@@ -141,7 +141,7 @@ class Pass(Object):
return _ffi_transform_api.RunPass(self, mod) return _ffi_transform_api.RunPass(self, mod)
@tvm._ffi.register_object("relay.ModulePass") @tvm._ffi.register_object("transform.ModulePass")
class ModulePass(Pass): class ModulePass(Pass):
"""A pass that works on tvm.IRModule. Users don't need to interact with """A pass that works on tvm.IRModule. Users don't need to interact with
this class directly. Instead, a module pass should be created through this class directly. Instead, a module pass should be created through
...@@ -152,7 +152,7 @@ class ModulePass(Pass): ...@@ -152,7 +152,7 @@ class ModulePass(Pass):
""" """
@tvm._ffi.register_object("relay.Sequential") @tvm._ffi.register_object("transform.Sequential")
class Sequential(Pass): class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be """A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class. executed sequentially using this class.
......
...@@ -46,7 +46,20 @@ class TypeKind(IntEnum): ...@@ -46,7 +46,20 @@ class TypeKind(IntEnum):
TypeData = 6 TypeData = 6
@tvm._ffi.register_object("relay.TypeVar") class PrimType(Type):
"""Primitive data type in the low level IR
Parameters
----------
dtype : str
The runtime data type relates to the primtype.
"""
def __init__(self, dtype):
self.__init_handle_by_constructor__(
_ffi_api.PrimType, dtype)
@tvm._ffi.register_object("TypeVar")
class TypeVar(Type): class TypeVar(Type):
"""Type parameter in functions. """Type parameter in functions.
...@@ -85,7 +98,7 @@ class TypeVar(Type): ...@@ -85,7 +98,7 @@ class TypeVar(Type):
return TypeCall(self, args) return TypeCall(self, args)
@tvm._ffi.register_object("relay.GlobalTypeVar") @tvm._ffi.register_object("GlobalTypeVar")
class GlobalTypeVar(Type): class GlobalTypeVar(Type):
"""A global type variable that is used for defining new types or type aliases. """A global type variable that is used for defining new types or type aliases.
...@@ -120,7 +133,7 @@ class GlobalTypeVar(Type): ...@@ -120,7 +133,7 @@ class GlobalTypeVar(Type):
return TypeCall(self, args) return TypeCall(self, args)
@tvm._ffi.register_object("relay.TupleType") @tvm._ffi.register_object("TupleType")
class TupleType(Type): class TupleType(Type):
"""The type of tuple values. """The type of tuple values.
...@@ -135,12 +148,12 @@ class TupleType(Type): ...@@ -135,12 +148,12 @@ class TupleType(Type):
_ffi_api.TupleType, fields) _ffi_api.TupleType, fields)
@tvm._ffi.register_object("relay.TypeConstraint") @tvm._ffi.register_object("TypeConstraint")
class TypeConstraint(Type): class TypeConstraint(Type):
"""Abstract class representing a type constraint.""" """Abstract class representing a type constraint."""
@tvm._ffi.register_object("relay.FuncType") @tvm._ffi.register_object("FuncType")
class FuncType(Type): class FuncType(Type):
"""Function type. """Function type.
...@@ -179,7 +192,7 @@ class FuncType(Type): ...@@ -179,7 +192,7 @@ class FuncType(Type):
_ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints) _ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)
@tvm._ffi.register_object("relay.IncompleteType") @tvm._ffi.register_object("IncompleteType")
class IncompleteType(Type): class IncompleteType(Type):
"""Incomplete type during type inference. """Incomplete type during type inference.
......
...@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint ...@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint
from . import _ffi_api from . import _ffi_api
@tvm._ffi.register_object("TypeCall")
class TypeCall(Type): class TypeCall(Type):
"""Type function application. """Type function application.
...@@ -41,7 +42,7 @@ class TypeCall(Type): ...@@ -41,7 +42,7 @@ class TypeCall(Type):
self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)
@tvm._ffi.register_object("relay.TypeRelation") @tvm._ffi.register_object("TypeRelation")
class TypeRelation(TypeConstraint): class TypeRelation(TypeConstraint):
"""User defined type relation, it is an input-output relation on types. """User defined type relation, it is an input-output relation on types.
......
...@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode { ...@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
*/ */
PassInfo Info() const override { return pass_info; } PassInfo Info() const override { return pass_info; }
static constexpr const char* _type_key = "relay.ModulePass"; static constexpr const char* _type_key = "transform.ModulePass";
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
}; };
...@@ -206,7 +206,7 @@ class SequentialNode : public PassNode { ...@@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
*/ */
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "relay.Sequential"; static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
}; };
......
...@@ -30,13 +30,6 @@ def check_json_roundtrip(node): ...@@ -30,13 +30,6 @@ def check_json_roundtrip(node):
assert graph_equal(back, node) assert graph_equal(back, node)
def test_bad_constructor():
try:
x = relay.ty.TensorType("xx", "xx")
except tvm.error.TVMError:
pass
# Span # Span
def test_span(): def test_span():
span = relay.Span(None, 1, 1) span = relay.Span(None, 1, 1)
...@@ -55,71 +48,6 @@ def test_span(): ...@@ -55,71 +48,6 @@ def test_span():
assert back.lineno == span.lineno assert back.lineno == span.lineno
assert back.col_offset == span.col_offset assert back.col_offset == span.col_offset
# Types
def test_tensor_type():
shape = tvm.runtime.convert([1, 2, 3])
dtype = 'float32'
tt = relay.TensorType(shape, dtype)
assert tt.dtype == dtype
assert tt.shape == shape
assert tt.span == None
str(tt)
check_json_roundtrip(tt)
def test_type_param():
tp = relay.TypeVar('name', relay.TypeKind.Type)
assert tp.kind == relay.TypeKind.Type
# assert tp.span # TODO allow us to set span
str(tp)
check_json_roundtrip(tp)
def test_func_type():
type_params = tvm.runtime.convert([])
type_constraints = tvm.runtime.convert([]) # TODO: fill me in
arg_types = tvm.runtime.convert([])
ret_type = relay.TensorType((1, 2, 3), 'float32')
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert tf.type_params == type_params
assert tf.type_constraints == type_constraints
assert tf.arg_types == arg_types
assert tf.ret_type == ret_type
assert tf.span == None
# TODO make sure we can set span
str(tf)
check_json_roundtrip(tf)
def test_tuple_type():
tp = relay.TypeVar('tp', relay.TypeKind.Type)
tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([]))
tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
fields = tvm.runtime.convert([tp, tf, tt])
tup_ty = relay.TupleType(fields)
assert tup_ty.fields == fields
str(tup_ty)
check_json_roundtrip(tup_ty)
def test_type_relation():
tp = relay.TypeVar('tp', relay.TypeKind.Type)
tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([]))
tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
args = tvm.runtime.convert([tp, tf, tt])
num_inputs = 2
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
str(tr)
check_json_roundtrip(tr)
def test_constant(): def test_constant():
arr = tvm.nd.array(10) arr = tvm.nd.array(10)
...@@ -280,13 +208,7 @@ def test_conv2d_attrs(): ...@@ -280,13 +208,7 @@ def test_conv2d_attrs():
if __name__ == "__main__": if __name__ == "__main__":
test_bad_constructor()
test_span() test_span()
test_tensor_type()
test_type_param()
test_func_type()
test_tuple_type()
test_type_relation()
test_constant() test_constant()
test_tuple() test_tuple()
test_local_var() test_local_var()
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
import tvm import tvm
from tvm import te from tvm import te
from tvm import relay
import json import json
def test_type_var(): def test_type_var():
...@@ -36,13 +35,81 @@ def test_type_var(): ...@@ -36,13 +35,81 @@ def test_type_var():
"b64ndarrays": [], "b64ndarrays": [],
} }
tvar = tvm.ir.load_json(json.dumps(data)) tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, relay.TypeVar) assert isinstance(tvar, tvm.ir.TypeVar)
assert tvar.name_hint == "in0" assert tvar.name_hint == "in0"
nodes[1]["type_key"] = "relay.GlobalTypeVar" nodes[1]["type_key"] = "relay.GlobalTypeVar"
tvar = tvm.ir.load_json(json.dumps(data)) tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, relay.GlobalTypeVar) assert isinstance(tvar, tvm.ir.GlobalTypeVar)
assert tvar.name_hint == "in0" assert tvar.name_hint == "in0"
def test_incomplete_type():
nodes = [
{"type_key": ""},
{"type_key": "relay.IncompleteType",
"attrs": {"kind": "0", "span": "0"}}]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, tvm.ir.IncompleteType)
def test_func_tuple_type():
nodes = [
{"type_key": ""},
{"type_key": "relay.FuncType",
"attrs": {
"arg_types": "2",
"ret_type": "3",
"span": "0",
"type_constraints": "6",
"type_params": "5"
}
},
{"type_key": "Array"},
{"type_key": "relay.TupleType",
"attrs": { "fields": "4", "span": "0" }},
{"type_key": "Array"},
{"type_key": "Array"},
{"type_key": "Array"}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, tvm.ir.FuncType)
def test_global_var():
nodes = [
{"type_key": ""},
{"type_key": "relay.GlobalVar",
"attrs": {
"_checked_type_": "0",
"name_hint": "x",
"span": "0"
}
}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, tvm.ir.GlobalVar)
if __name__ == "__main__": if __name__ == "__main__":
test_type_var() test_type_var()
test_incomplete_type()
test_func_tuple_type()
test_global_var()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test type nodes in the IR"""
import tvm
def check_json_roundtrip(node):
from tvm.relay.analysis import graph_equal
json_str = tvm.ir.save_json(node)
back = tvm.ir.load_json(json_str)
assert graph_equal(back, node)
def test_prim_type():
x = tvm.ir.PrimType("int32")
assert isinstance(x, tvm.ir.PrimType)
assert x.dtype == "int32"
def test_tensor_type_bad_constructor():
try:
x = tvm.ir.TensorType("xx", "xx")
except tvm.error.TVMError:
pass
def test_tensor_type():
shape = tvm.runtime.convert([1, 2, 3])
dtype = 'float32'
tt = tvm.ir.TensorType(shape, dtype)
assert tt.dtype == dtype
assert tt.shape == shape
assert tt.span == None
str(tt)
check_json_roundtrip(tt)
def test_type_param():
tp = tvm.ir.TypeVar('name', tvm.ir.TypeKind.Type)
assert tp.kind == tvm.ir.TypeKind.Type
# assert tp.span # TODO allow us to set span
str(tp)
check_json_roundtrip(tp)
def test_func_type():
type_params = tvm.runtime.convert([])
type_constraints = tvm.runtime.convert([]) # TODO: fill me in
arg_types = tvm.runtime.convert([])
ret_type = tvm.ir.TensorType((1, 2, 3), 'float32')
tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints)
assert tf.type_params == type_params
assert tf.type_constraints == type_constraints
assert tf.arg_types == arg_types
assert tf.ret_type == ret_type
assert tf.span == None
# TODO make sure we can set span
str(tf)
check_json_roundtrip(tf)
def test_tuple_type():
tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
tf = tvm.ir.FuncType([], None, [], [])
tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
fields = tvm.runtime.convert([tp, tf, tt])
tup_ty = tvm.ir.TupleType(fields)
assert tup_ty.fields == fields
str(tup_ty)
check_json_roundtrip(tup_ty)
def test_type_relation():
tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
tf = tvm.ir.FuncType([], None, [], [])
tt = tvm.ir.TensorType(
tvm.runtime.convert([1, 2, 3]), 'float32')
args = tvm.runtime.convert([tp, tf, tt])
num_inputs = 2
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = tvm.ir.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
str(tr)
check_json_roundtrip(tr)
if __name__ == "__main__":
test_tensor_type_bad_constructor()
test_tensor_type()
test_type_param()
test_func_type()
test_tuple_type()
test_type_relation()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment