Commit 09f0ac33 by Logan Weber Committed by Jared Roesch

[Relay] Add Python type functor and tests (#4209)

* Add Python type functor and tests

* Lint roller
parent 2386e74b
...@@ -23,6 +23,7 @@ from ..api import register_func ...@@ -23,6 +23,7 @@ from ..api import register_func
from . import base from . import base
from . import ty from . import ty
from . import expr from . import expr
from . import type_functor
from . import expr_functor from . import expr_functor
from . import module from . import module
from . import adt from . import adt
...@@ -118,6 +119,11 @@ module_pass = transform.module_pass ...@@ -118,6 +119,11 @@ module_pass = transform.module_pass
function_pass = transform.function_pass function_pass = transform.function_pass
alpha_equal = analysis.alpha_equal alpha_equal = analysis.alpha_equal
# TypeFunctor
TypeFunctor = type_functor.TypeFunctor
TypeVisitor = type_functor.TypeVisitor
TypeMutator = type_functor.TypeMutator
# ExprFunctor # ExprFunctor
ExprFunctor = expr_functor.ExprFunctor ExprFunctor = expr_functor.ExprFunctor
ExprVisitor = expr_functor.ExprVisitor ExprVisitor = expr_functor.ExprVisitor
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The type functor of Relay."""
from .ty import (TypeVar, IncompleteType, TensorType, FuncType,
TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
from .adt import TypeData
class TypeFunctor:
"""
An abstract visitor defined over Type.
Defines the default dispatch over types.
"""
def __init__(self):
# TODO(weberlo): make type vars hashable, so we can memoize
pass
# pylint: disable=no-else-return
def visit(self, typ):
"""Apply the visitor to a type."""
if isinstance(typ, TypeVar):
return self.visit_type_var(typ)
elif isinstance(typ, IncompleteType):
return self.visit_incomplete_type(typ)
elif isinstance(typ, TensorType):
return self.visit_tensor_type(typ)
elif isinstance(typ, FuncType):
return self.visit_func_type(typ)
elif isinstance(typ, TupleType):
return self.visit_tuple_type(typ)
elif isinstance(typ, TypeRelation):
return self.visit_type_relation(typ)
elif isinstance(typ, RefType):
return self.visit_ref_type(typ)
elif isinstance(typ, GlobalTypeVar):
return self.visit_global_type_var(typ)
elif isinstance(typ, TypeCall):
return self.visit_type_call(typ)
elif isinstance(typ, TypeData):
return self.visit_type_data(typ)
else:
raise Exception('unhandled case: {0}'.format(type(typ)))
def visit_type_var(self, _):
raise NotImplementedError()
def visit_incomplete_type(self, _):
raise NotImplementedError()
def visit_tensor_type(self, _):
raise NotImplementedError()
def visit_func_type(self, _):
raise NotImplementedError()
def visit_tuple_type(self, _):
raise NotImplementedError()
def visit_type_relation(self, _):
raise NotImplementedError()
def visit_ref_type(self, _):
raise NotImplementedError()
def visit_global_type_var(self, _):
raise NotImplementedError()
def visit_type_call(self, _):
raise NotImplementedError()
def visit_type_data(self, _):
raise NotImplementedError()
class TypeVisitor(TypeFunctor):
"""
A visitor over Type.
The default behavior recursively traverses the AST.
"""
def visit_type_var(self, tv):
pass
def visit_incomplete_type(self, it):
pass
def visit_tensor_type(self, tt):
pass
def visit_func_type(self, ft):
for arg_type in ft.arg_types:
self.visit(arg_type)
self.visit(ft.ret_type)
for type_param in getattr(ft, 'type_params', []):
self.visit(type_param)
for type_constraint in getattr(ft, 'type_constraints', []):
self.visit(type_constraint)
def visit_tuple_type(self, tt):
for field in tt.fields:
self.visit(field)
def visit_type_relation(self, tr):
for arg in tr.args:
self.visit(arg)
def visit_ref_type(self, rt):
self.visit(rt.value)
def visit_global_type_var(self, gtv):
pass
def visit_type_call(self, tc):
self.visit(tc.func)
for arg in tc.args:
self.visit(arg)
def visit_type_data(self, td):
self.visit(td.header)
for type_var in td.type_vars:
self.visit(type_var)
class TypeMutator(TypeFunctor):
"""
A functional visitor over Type.
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def visit_type_var(self, tv):
return TypeVar(tv.var.name, tv.kind)
def visit_incomplete_type(self, it):
return IncompleteType(it.kind)
def visit_tensor_type(self, tt):
return TensorType(tt.shape, tt.dtype)
def visit_func_type(self, ft):
new_arg_types = [self.visit(arg_type) for arg_type in ft.arg_types]
new_ret_type = self.visit(ft.ret_type)
new_type_params = [
self.visit(type_param)
for type_param in getattr(ft, 'type_params', [])]
new_type_constraints = [
self.visit(type_constraint)
for type_constraint in getattr(ft, 'type_constraints', [])]
return FuncType(
new_arg_types,
new_ret_type,
new_type_params,
new_type_constraints)
def visit_tuple_type(self, tt):
return TupleType([self.visit(field) for field in tt.fields])
def visit_type_relation(self, tr):
return TypeRelation(
tr.func,
[self.visit(arg) for arg in tr.args],
tr.num_inputs,
tr.attrs)
def visit_ref_type(self, rt):
return RefType(self.visit(rt.value))
def visit_global_type_var(self, gtv):
return GlobalTypeVar(gtv.var.name, gtv.kind)
def visit_type_call(self, tc):
return TypeCall(
self.visit(tc.func),
[self.visit(arg) for arg in tc.args])
def visit_type_data(self, td):
return TypeData(
self.visit(td.header),
[self.visit(type_var) for type_var in td.type_vars],
td.constructors)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
from tvm.relay import TypeFunctor, TypeMutator, TypeVisitor
from tvm.relay.analysis import assert_graph_equal
from tvm.relay.ty import (TypeVar, IncompleteType, TensorType, FuncType,
TupleType, TypeRelation, RefType, GlobalTypeVar, TypeCall)
from tvm.relay.adt import TypeData
def check_visit(typ):
try:
ef = TypeFunctor()
ef.visit(typ)
assert False
except NotImplementedError:
pass
ev = TypeVisitor()
ev.visit(typ)
assert_graph_equal(TypeMutator().visit(typ), typ)
def test_type_var():
tv = TypeVar('a')
check_visit(tv)
def test_incomplete_type():
it = IncompleteType()
check_visit(it)
def test_tensor_type():
tt = TensorType([])
check_visit(tt)
def test_func_type():
tv = TypeVar('tv')
tt = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
ft = FuncType([tt], tt, type_params=[tv])
check_visit(ft)
def test_tuple_type():
tt = TupleType([TupleType([])])
check_visit(tt)
def test_type_relation():
func = tvm.get_env_func('tvm.relay.type_relation.Broadcast')
attrs = tvm.make.node('attrs.TestAttrs', name='attr', padding=(3,4))
tp = TypeVar('tp')
tf = FuncType([], TupleType([]), [], [])
tt = TensorType([1, 2, 3], 'float32')
tr = TypeRelation(func, [tp, tf, tt], 2, attrs)
check_visit(tr)
def test_ref_type():
rt = RefType(TupleType([]))
check_visit(rt)
def test_global_type_var():
gtv = GlobalTypeVar('gtv')
check_visit(gtv)
def test_type_call():
tc = TypeCall(GlobalTypeVar('tf'), [TupleType([])])
check_visit(tc)
def test_type_data():
td = TypeData(GlobalTypeVar('td'), [TypeVar('tv')], [])
check_visit(td)
if __name__ == "__main__":
test_type_var()
test_incomplete_type()
test_tensor_type()
test_func_type()
test_tuple_type()
test_type_relation()
test_ref_type()
test_global_type_var()
test_type_call()
test_type_data()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment