Unverified Commit 6027412b by Tianqi Chen Committed by GitHub

[IR] Update the type_keys to reflect the code-org (#5074)

parent 7c5ff508
......@@ -196,7 +196,7 @@ class GlobalVarNode : public RelayExprNode {
v->Visit("_checked_type_", &checked_type_);
}
static constexpr const char* _type_key = "relay.GlobalVar";
static constexpr const char* _type_key = "GlobalVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelayExprNode);
};
......
......@@ -226,7 +226,7 @@ class IRModuleNode : public Object {
*/
TVM_DLL std::unordered_set<std::string> Imports() const;
static constexpr const char* _type_key = "relay.Module";
static constexpr const char* _type_key = "IRModule";
TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object);
private:
......
......@@ -44,7 +44,7 @@ class SourceNameNode : public Object {
// override attr visitor
void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static constexpr const char* _type_key = "relay.SourceName";
static constexpr const char* _type_key = "SourceName";
TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object);
};
......@@ -89,7 +89,7 @@ class SpanNode : public Object {
TVM_DLL static Span make(SourceName source, int lineno, int col_offset);
static constexpr const char* _type_key = "relay.Span";
static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};
......
......@@ -110,7 +110,7 @@ class PassContextNode : public Object {
v->Visit("disabled_pass", &disabled_pass);
}
static constexpr const char* _type_key = "relay.PassContext";
static constexpr const char* _type_key = "transform.PassContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
};
......@@ -206,7 +206,7 @@ class PassInfoNode : public Object {
v->Visit("required", &required);
}
static constexpr const char* _type_key = "relay.PassInfo";
static constexpr const char* _type_key = "transform.PassInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object);
};
......@@ -265,7 +265,7 @@ class PassNode : public Object {
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.Pass";
static constexpr const char* _type_key = "transform.Pass";
TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object);
};
......
......@@ -78,7 +78,7 @@ class TypeNode : public Object {
*/
mutable Span span;
static constexpr const char* _type_key = "relay.Type";
static constexpr const char* _type_key = "Type";
TVM_DECLARE_BASE_OBJECT_INFO(TypeNode, Object);
};
......@@ -110,7 +110,7 @@ class PrimTypeNode : public TypeNode {
v->Visit("dtype", &dtype);
}
static constexpr const char* _type_key = "relay.PrimType";
static constexpr const char* _type_key = "PrimType";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
......@@ -175,7 +175,7 @@ class TypeVarNode : public TypeNode {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TypeVar";
static constexpr const char* _type_key = "TypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeVarNode, TypeNode);
};
......@@ -215,7 +215,7 @@ class GlobalTypeVarNode : public TypeNode {
v->Visit("kind", &kind);
}
static constexpr const char* _type_key = "relay.GlobalTypeVar";
static constexpr const char* _type_key = "GlobalTypeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(GlobalTypeVarNode, TypeNode);
};
......@@ -251,7 +251,7 @@ class TupleTypeNode : public TypeNode {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TupleType";
static constexpr const char* _type_key = "TupleType";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode);
};
......@@ -289,7 +289,7 @@ inline Type VoidType() {
*/
class TypeConstraintNode : public TypeNode {
public:
static constexpr const char* _type_key = "relay.TypeConstraint";
static constexpr const char* _type_key = "TypeConstraint";
TVM_DECLARE_BASE_OBJECT_INFO(TypeConstraintNode, TypeNode);
};
......@@ -334,7 +334,7 @@ class FuncTypeNode : public TypeNode {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.FuncType";
static constexpr const char* _type_key = "FuncType";
TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode);
};
......@@ -380,7 +380,7 @@ class IncompleteTypeNode : public TypeNode {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.IncompleteType";
static constexpr const char* _type_key = "IncompleteType";
TVM_DECLARE_FINAL_OBJECT_INFO(IncompleteTypeNode, TypeNode);
};
......@@ -417,6 +417,8 @@ class RelayRefTypeNode : public TypeNode {
v->Visit("span", &span);
}
// Keep the relay prefix in the type as this type is specific
// to the relay itself.
static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_FINAL_OBJECT_INFO(RelayRefTypeNode, TypeNode);
};
......
......@@ -50,7 +50,7 @@ class TypeCallNode : public TypeNode {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TypeCall";
static constexpr const char* _type_key = "TypeCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeCallNode, TypeNode);
};
......@@ -119,7 +119,7 @@ class TypeReporterNode : public Object {
// solver is not serializable.
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter";
static constexpr const char* _type_key = "TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};
......@@ -195,7 +195,7 @@ class TypeRelationNode : public TypeConstraintNode {
v->Visit("span", &span);
}
static constexpr const char* _type_key = "relay.TypeRelation";
static constexpr const char* _type_key = "TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
......
......@@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .type import Type, TypeKind, TypeVar, GlobalTypeVar, TupleType
from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
......
......@@ -56,7 +56,7 @@ class Node(Object):
return _ffi_api.PrettyPrint(self)
@tvm._ffi.register_object("relay.SourceName")
@tvm._ffi.register_object("SourceName")
class SourceName(Object):
"""A identifier for a source location.
......@@ -69,7 +69,7 @@ class SourceName(Object):
self.__init_handle_by_constructor__(_ffi_api.SourceName, name)
@tvm._ffi.register_object("relay.Span")
@tvm._ffi.register_object("Span")
class Span(Object):
"""Specifies a location in a source program.
......
......@@ -51,7 +51,7 @@ class RelayExpr(BaseExpr):
return ret
@tvm._ffi.register_object("relay.GlobalVar")
@tvm._ffi.register_object("GlobalVar")
class GlobalVar(RelayExpr):
"""A global variable in the IR.
......
......@@ -62,11 +62,35 @@ def create_updater_06_to_07():
# set vindex to null
nodes[vindex]["type_key"] = ""
del item["attrs"]["var"]
assert item["type_key"].startswith("relay.")
item["type_key"] = item["type_key"][len("relay."):]
return item
def _rename(new_name):
def _convert(item, _):
item["type_key"] = new_name
return item
return _convert
node_map = {
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
"relay.TypeConstraint": _rename("TypeConstraint"),
"relay.FuncType": _rename("FuncType"),
"relay.IncompleteType": _rename("IncompleteType"),
"relay.TypeRelation": _rename("TypeRelation"),
"relay.TypeCall": _rename("TypeCall"),
"relay.Module": _rename("IRModule"),
"relay.SourceName": _rename("SourceName"),
"relay.Span": _rename("Span"),
"relay.GlobalVar": _rename("GlobalVar"),
"relay.Pass": _rename("transform.Pass"),
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
}
return create_updater(node_map, "0.6", "0.7")
......
......@@ -24,7 +24,7 @@ from . import type as _ty
from . import _ffi_api
@tvm._ffi.register_object("relay.Module")
@tvm._ffi.register_object("IRModule")
class IRModule(Node):
"""IRModule that holds functions and type definitions.
......
......@@ -27,7 +27,7 @@ from tvm.runtime import Object, ndarray as _nd
from . import _ffi_transform_api
@tvm._ffi.register_object("relay.PassInfo")
@tvm._ffi.register_object("transform.PassInfo")
class PassInfo(Object):
"""The class contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
......@@ -51,7 +51,7 @@ class PassInfo(Object):
_ffi_transform_api.PassInfo, opt_level, name, required)
@tvm._ffi.register_object("relay.PassContext")
@tvm._ffi.register_object("transform.PassContext")
class PassContext(Object):
"""The basis where a Relay optimization/analysis runs on.
Each pass context contains a number of auxiliary information that is used
......@@ -112,7 +112,7 @@ class PassContext(Object):
return _ffi_transform_api.GetCurrentPassContext()
@tvm._ffi.register_object("relay.Pass")
@tvm._ffi.register_object("transform.Pass")
class Pass(Object):
"""The base class of all passes. All methods here are just simple wrappers
that are implemented in the backend. They are defined for users to
......@@ -141,7 +141,7 @@ class Pass(Object):
return _ffi_transform_api.RunPass(self, mod)
@tvm._ffi.register_object("relay.ModulePass")
@tvm._ffi.register_object("transform.ModulePass")
class ModulePass(Pass):
"""A pass that works on tvm.IRModule. Users don't need to interact with
this class directly. Instead, a module pass should be created through
......@@ -152,7 +152,7 @@ class ModulePass(Pass):
"""
@tvm._ffi.register_object("relay.Sequential")
@tvm._ffi.register_object("transform.Sequential")
class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
......
......@@ -46,7 +46,20 @@ class TypeKind(IntEnum):
TypeData = 6
@tvm._ffi.register_object("relay.TypeVar")
class PrimType(Type):
"""Primitive data type in the low level IR
Parameters
----------
dtype : str
The runtime data type relates to the primtype.
"""
def __init__(self, dtype):
self.__init_handle_by_constructor__(
_ffi_api.PrimType, dtype)
@tvm._ffi.register_object("TypeVar")
class TypeVar(Type):
"""Type parameter in functions.
......@@ -85,7 +98,7 @@ class TypeVar(Type):
return TypeCall(self, args)
@tvm._ffi.register_object("relay.GlobalTypeVar")
@tvm._ffi.register_object("GlobalTypeVar")
class GlobalTypeVar(Type):
"""A global type variable that is used for defining new types or type aliases.
......@@ -120,7 +133,7 @@ class GlobalTypeVar(Type):
return TypeCall(self, args)
@tvm._ffi.register_object("relay.TupleType")
@tvm._ffi.register_object("TupleType")
class TupleType(Type):
"""The type of tuple values.
......@@ -135,12 +148,12 @@ class TupleType(Type):
_ffi_api.TupleType, fields)
@tvm._ffi.register_object("relay.TypeConstraint")
@tvm._ffi.register_object("TypeConstraint")
class TypeConstraint(Type):
"""Abstract class representing a type constraint."""
@tvm._ffi.register_object("relay.FuncType")
@tvm._ffi.register_object("FuncType")
class FuncType(Type):
"""Function type.
......@@ -179,7 +192,7 @@ class FuncType(Type):
_ffi_api.FuncType, arg_types, ret_type, type_params, type_constraints)
@tvm._ffi.register_object("relay.IncompleteType")
@tvm._ffi.register_object("IncompleteType")
class IncompleteType(Type):
"""Incomplete type during type inference.
......
......@@ -21,6 +21,7 @@ from .type import Type, TypeConstraint
from . import _ffi_api
@tvm._ffi.register_object("TypeCall")
class TypeCall(Type):
"""Type function application.
......@@ -41,7 +42,7 @@ class TypeCall(Type):
self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args)
@tvm._ffi.register_object("relay.TypeRelation")
@tvm._ffi.register_object("TypeRelation")
class TypeRelation(TypeConstraint):
"""User defined type relation, it is an input-output relation on types.
......
......@@ -132,7 +132,7 @@ class ModulePassNode : public PassNode {
*/
PassInfo Info() const override { return pass_info; }
static constexpr const char* _type_key = "relay.ModulePass";
static constexpr const char* _type_key = "transform.ModulePass";
TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode);
};
......@@ -206,7 +206,7 @@ class SequentialNode : public PassNode {
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "relay.Sequential";
static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
};
......
......@@ -30,13 +30,6 @@ def check_json_roundtrip(node):
assert graph_equal(back, node)
def test_bad_constructor():
try:
x = relay.ty.TensorType("xx", "xx")
except tvm.error.TVMError:
pass
# Span
def test_span():
span = relay.Span(None, 1, 1)
......@@ -55,71 +48,6 @@ def test_span():
assert back.lineno == span.lineno
assert back.col_offset == span.col_offset
# Types
def test_tensor_type():
shape = tvm.runtime.convert([1, 2, 3])
dtype = 'float32'
tt = relay.TensorType(shape, dtype)
assert tt.dtype == dtype
assert tt.shape == shape
assert tt.span == None
str(tt)
check_json_roundtrip(tt)
def test_type_param():
tp = relay.TypeVar('name', relay.TypeKind.Type)
assert tp.kind == relay.TypeKind.Type
# assert tp.span # TODO allow us to set span
str(tp)
check_json_roundtrip(tp)
def test_func_type():
type_params = tvm.runtime.convert([])
type_constraints = tvm.runtime.convert([]) # TODO: fill me in
arg_types = tvm.runtime.convert([])
ret_type = relay.TensorType((1, 2, 3), 'float32')
tf = relay.FuncType(arg_types, ret_type, type_params, type_constraints)
assert tf.type_params == type_params
assert tf.type_constraints == type_constraints
assert tf.arg_types == arg_types
assert tf.ret_type == ret_type
assert tf.span == None
# TODO make sure we can set span
str(tf)
check_json_roundtrip(tf)
def test_tuple_type():
tp = relay.TypeVar('tp', relay.TypeKind.Type)
tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([]))
tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
fields = tvm.runtime.convert([tp, tf, tt])
tup_ty = relay.TupleType(fields)
assert tup_ty.fields == fields
str(tup_ty)
check_json_roundtrip(tup_ty)
def test_type_relation():
tp = relay.TypeVar('tp', relay.TypeKind.Type)
tf = relay.FuncType(tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([]))
tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
args = tvm.runtime.convert([tp, tf, tt])
num_inputs = 2
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = relay.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
str(tr)
check_json_roundtrip(tr)
def test_constant():
arr = tvm.nd.array(10)
......@@ -280,13 +208,7 @@ def test_conv2d_attrs():
if __name__ == "__main__":
test_bad_constructor()
test_span()
test_tensor_type()
test_type_param()
test_func_type()
test_tuple_type()
test_type_relation()
test_constant()
test_tuple()
test_local_var()
......
......@@ -17,7 +17,6 @@
import tvm
from tvm import te
from tvm import relay
import json
def test_type_var():
......@@ -36,13 +35,81 @@ def test_type_var():
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, relay.TypeVar)
assert isinstance(tvar, tvm.ir.TypeVar)
assert tvar.name_hint == "in0"
nodes[1]["type_key"] = "relay.GlobalTypeVar"
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, relay.GlobalTypeVar)
assert isinstance(tvar, tvm.ir.GlobalTypeVar)
assert tvar.name_hint == "in0"
def test_incomplete_type():
nodes = [
{"type_key": ""},
{"type_key": "relay.IncompleteType",
"attrs": {"kind": "0", "span": "0"}}]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, tvm.ir.IncompleteType)
def test_func_tuple_type():
nodes = [
{"type_key": ""},
{"type_key": "relay.FuncType",
"attrs": {
"arg_types": "2",
"ret_type": "3",
"span": "0",
"type_constraints": "6",
"type_params": "5"
}
},
{"type_key": "Array"},
{"type_key": "relay.TupleType",
"attrs": { "fields": "4", "span": "0" }},
{"type_key": "Array"},
{"type_key": "Array"},
{"type_key": "Array"}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, tvm.ir.FuncType)
def test_global_var():
nodes = [
{"type_key": ""},
{"type_key": "relay.GlobalVar",
"attrs": {
"_checked_type_": "0",
"name_hint": "x",
"span": "0"
}
}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, tvm.ir.GlobalVar)
if __name__ == "__main__":
test_type_var()
test_incomplete_type()
test_func_tuple_type()
test_global_var()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test type nodes in the IR"""
import tvm
def check_json_roundtrip(node):
from tvm.relay.analysis import graph_equal
json_str = tvm.ir.save_json(node)
back = tvm.ir.load_json(json_str)
assert graph_equal(back, node)
def test_prim_type():
x = tvm.ir.PrimType("int32")
assert isinstance(x, tvm.ir.PrimType)
assert x.dtype == "int32"
def test_tensor_type_bad_constructor():
try:
x = tvm.ir.TensorType("xx", "xx")
except tvm.error.TVMError:
pass
def test_tensor_type():
shape = tvm.runtime.convert([1, 2, 3])
dtype = 'float32'
tt = tvm.ir.TensorType(shape, dtype)
assert tt.dtype == dtype
assert tt.shape == shape
assert tt.span == None
str(tt)
check_json_roundtrip(tt)
def test_type_param():
tp = tvm.ir.TypeVar('name', tvm.ir.TypeKind.Type)
assert tp.kind == tvm.ir.TypeKind.Type
# assert tp.span # TODO allow us to set span
str(tp)
check_json_roundtrip(tp)
def test_func_type():
type_params = tvm.runtime.convert([])
type_constraints = tvm.runtime.convert([]) # TODO: fill me in
arg_types = tvm.runtime.convert([])
ret_type = tvm.ir.TensorType((1, 2, 3), 'float32')
tf = tvm.ir.FuncType(arg_types, ret_type, type_params, type_constraints)
assert tf.type_params == type_params
assert tf.type_constraints == type_constraints
assert tf.arg_types == arg_types
assert tf.ret_type == ret_type
assert tf.span == None
# TODO make sure we can set span
str(tf)
check_json_roundtrip(tf)
def test_tuple_type():
tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
tf = tvm.ir.FuncType([], None, [], [])
tt = tvm.ir.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
fields = tvm.runtime.convert([tp, tf, tt])
tup_ty = tvm.ir.TupleType(fields)
assert tup_ty.fields == fields
str(tup_ty)
check_json_roundtrip(tup_ty)
def test_type_relation():
tp = tvm.ir.TypeVar('tp', tvm.ir.TypeKind.Type)
tf = tvm.ir.FuncType([], None, [], [])
tt = tvm.ir.TensorType(
tvm.runtime.convert([1, 2, 3]), 'float32')
args = tvm.runtime.convert([tp, tf, tt])
num_inputs = 2
func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))
tr = tvm.ir.TypeRelation(func, args, num_inputs, attrs)
assert tr.args == args
assert tr.num_inputs == num_inputs
str(tr)
check_json_roundtrip(tr)
if __name__ == "__main__":
test_tensor_type_bad_constructor()
test_tensor_type()
test_type_param()
test_func_type()
test_tuple_type()
test_type_relation()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment