Commit d76712d1 by Tianqi Chen Committed by GitHub

[TOPI] Move topi.nn.util to topi.util (#319)

* [TOPI] Move topi.nn.util to topi.util

* update the path
parent f08de2b6
# pylint: disable=invalid-name
"""Schedule for depthwise_conv2d with auto fusion"""
import tvm
from ..nn.util import get_const_tuple
from ..util import get_const_tuple
def schedule_depthwise_conv2d_map(op):
"""Schedule for depthwise_conv2d map ops.
......
......@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
import tvm
import numpy as np
from .util import get_const_tuple
from ..util import get_const_tuple
@tvm.tag_scope(tag="conv2d_hwcn")
......
......@@ -2,12 +2,32 @@
from __future__ import absolute_import as _abs
import tvm
def get_const_int(expr):
"""Verifies expr is integer and get the constant value.
Parameters
----------
expr :
The input expression.
Returns
-------
out_tuple : tuple of int
The output.
"""
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
expr = tvm.ir_pass.Simplfy(expr)
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Expect value to be constant int")
return expr.value
def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
Parameters
----------
in_tuple : tuple of tvm.expr.IntImm
in_tuple : tuple of Expr
The input.
Returns
......@@ -17,7 +37,7 @@ def get_const_tuple(in_tuple):
"""
out_tuple = ()
for elem in in_tuple:
if not isinstance(elem, tvm.expr.IntImm):
raise ValueError("Element of input tuple should be IntImm")
if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Element of input tuple should be const int")
out_tuple = out_tuple + (elem.value, )
return out_tuple
import tvm
import topi
from topi import util
def test_util():
x = tvm.const(100)
assert util.get_const_int(x) == 100
assert util.get_const_tuple((x, x)) == (100, 100)
def test_ewise():
m = tvm.var('m')
......@@ -19,4 +27,5 @@ def test_ewise():
if __name__ == "__main__":
test_util()
test_ewise()
......@@ -3,7 +3,7 @@ import os
import numpy as np
import tvm
import topi
from topi.nn.util import get_const_tuple
from topi.util import get_const_tuple
def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding):
......
......@@ -2,7 +2,7 @@ import tvm
import topi
import numpy as np
from scipy import signal
from topi.nn.util import get_const_tuple
from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment