# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utilities for testing and benchmarks"""
from __future__ import absolute_import as _abs

import tvm.relay as relay
from tvm.relay import transform

from . import mlp
from . import resnet
from . import dqn
from . import dcgan
from . import mobilenet
from . import lstm
from . import inception_v3
from . import squeezenet
from . import vgg
from . import densenet
from . import yolo_detection

from .config import ctx_list
from .init import create_workload
from .nat import add_nat_definitions, count, make_nat_value, make_nat_expr
from .py_converter import to_python, run_as_python


def run_opt_pass(expr, opt_pass):
    """Run a single Relay pass over ``expr`` and return the transformed result.

    Parameters
    ----------
    expr : relay.Expr
        The expression to transform. It is wrapped into a module with
        ``relay.Module.from_expr`` before the pass runs.
    opt_pass : transform.Pass
        The pass to apply to that module.

    Returns
    -------
    relay.Expr
        The transformed ``main`` function if ``expr`` is a ``relay.Function``,
        otherwise the body of the transformed ``main`` function.

    Raises
    ------
    AssertionError
        If ``opt_pass`` is not a ``transform.Pass``.
    """
    assert isinstance(opt_pass, transform.Pass), \
        "opt_pass must be a tvm.relay.transform.Pass"
    mod = relay.Module.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    # For a non-function input, "main" is presumably a wrapper that
    # Module.from_expr built around expr; return its body so the caller
    # gets back the same kind of expression it passed in.
    return entry if isinstance(expr, relay.Function) else entry.body


def run_infer_type(expr):
    """Return ``expr`` after running Relay type inference over it.

    Thin wrapper: applies ``transform.InferType()`` through ``run_opt_pass``,
    so a ``relay.Function`` input yields the typed function and any other
    expression yields the typed expression body.
    """
    infer_pass = transform.InferType()
    return run_opt_pass(expr, infer_pass)