Commit cf3e7865 by Alexander Pivovarov Committed by masahi

Remove run_infer_type duplicates (#4766)

parent 4dbe4d98
......@@ -20,15 +20,10 @@ import tvm
import scipy
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing
from tvm.contrib.nvcc import have_fp16
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def sigmoid(x):
one = np.ones_like(x)
......
......@@ -21,15 +21,10 @@ import tvm
import topi.testing
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi
import topi.testing
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_checkpoint():
dtype = "float32"
......
......@@ -21,13 +21,8 @@ import pytest
import tvm
from tvm import relay
from tvm.relay import create_executor, transform
from tvm.relay.testing import ctx_list, check_grad
from tvm.relay.testing import ctx_list, check_grad, run_infer_type
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_zeros_ones():
for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
......
......@@ -18,14 +18,9 @@ import tvm
import numpy as np
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_binary_op():
def check_binary_op(opfunc, ref):
......
......@@ -21,14 +21,9 @@ import numpy as np
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.testing import ctx_list
from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_resize_infer_type():
n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
......
......@@ -24,14 +24,7 @@ from tvm.relay import ExprFunctor
from tvm.relay import Function, Call
from tvm.relay import analysis
from tvm.relay import transform as _transform
from tvm.relay.testing import ctx_list
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = _transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
from tvm.relay.testing import ctx_list, run_infer_type
def get_var_func():
......
......@@ -21,15 +21,11 @@ from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass
from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass
from tvm.relay import create_executor
from tvm.relay import Function, transform
def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id():
x = relay.var("x", shape=[])
id = run_infer_type(relay.Function([x], x))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment