Commit cf3e7865 by Alexander Pivovarov Committed by masahi

Remove run_infer_type duplicates (#4766)

parent 4dbe4d98
...@@ -20,15 +20,10 @@ import tvm ...@@ -20,15 +20,10 @@ import tvm
import scipy import scipy
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing import topi.testing
from tvm.contrib.nvcc import have_fp16 from tvm.contrib.nvcc import have_fp16
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def sigmoid(x): def sigmoid(x):
one = np.ones_like(x) one = np.ones_like(x)
......
...@@ -21,15 +21,10 @@ import tvm ...@@ -21,15 +21,10 @@ import tvm
import topi.testing import topi.testing
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list, run_infer_type
import topi import topi
import topi.testing import topi.testing
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_checkpoint(): def test_checkpoint():
dtype = "float32" dtype = "float32"
......
...@@ -21,13 +21,8 @@ import pytest ...@@ -21,13 +21,8 @@ import pytest
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import create_executor, transform from tvm.relay import create_executor, transform
from tvm.relay.testing import ctx_list, check_grad from tvm.relay.testing import ctx_list, check_grad, run_infer_type
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_zeros_ones(): def test_zeros_ones():
for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]: for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]:
......
...@@ -18,14 +18,9 @@ import tvm ...@@ -18,14 +18,9 @@ import tvm
import numpy as np import numpy as np
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing import topi.testing
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_binary_op(): def test_binary_op():
def check_binary_op(opfunc, ref): def check_binary_op(opfunc, ref):
......
...@@ -21,14 +21,9 @@ import numpy as np ...@@ -21,14 +21,9 @@ import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list, run_infer_type
import topi.testing import topi.testing
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_resize_infer_type(): def test_resize_infer_type():
n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w") n, c, h, w = tvm.size_var("n"), tvm.size_var("c"), tvm.size_var("h"), tvm.size_var("w")
......
...@@ -24,14 +24,7 @@ from tvm.relay import ExprFunctor ...@@ -24,14 +24,7 @@ from tvm.relay import ExprFunctor
from tvm.relay import Function, Call from tvm.relay import Function, Call
from tvm.relay import analysis from tvm.relay import analysis
from tvm.relay import transform as _transform from tvm.relay import transform as _transform
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list, run_infer_type
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = _transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def get_var_func(): def get_var_func():
......
...@@ -21,15 +21,11 @@ from tvm.relay.analysis import alpha_equal, detect_feature ...@@ -21,15 +21,11 @@ from tvm.relay.analysis import alpha_equal, detect_feature
from tvm.relay.transform import to_cps, un_cps from tvm.relay.transform import to_cps, un_cps
from tvm.relay.feature import Feature from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type, run_opt_pass from tvm.relay.testing import add_nat_definitions, make_nat_expr, rand, run_infer_type, run_opt_pass
from tvm.relay import create_executor from tvm.relay import create_executor
from tvm.relay import Function, transform from tvm.relay import Function, transform
def rand(dtype='float32', *shape):
return tvm.nd.array(np.random.rand(*shape).astype(dtype))
def test_id(): def test_id():
x = relay.var("x", shape=[]) x = relay.var("x", shape=[])
id = run_infer_type(relay.Function([x], x)) id = run_infer_type(relay.Function([x], x))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment