Commit d76712d1 by Tianqi Chen Committed by GitHub

[TOPI] Move topi.nn.util to topi.util (#319)

* [TOPI] Move topi.nn.util to topi.util

* update the path
parent f08de2b6
# pylint: disable=invalid-name # pylint: disable=invalid-name
"""Schedule for depthwise_conv2d with auto fusion""" """Schedule for depthwise_conv2d with auto fusion"""
import tvm import tvm
from ..nn.util import get_const_tuple from ..util import get_const_tuple
def schedule_depthwise_conv2d_map(op): def schedule_depthwise_conv2d_map(op):
"""Schedule for depthwise_conv2d map ops. """Schedule for depthwise_conv2d map ops.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
import numpy as np import numpy as np
from .util import get_const_tuple from ..util import get_const_tuple
@tvm.tag_scope(tag="conv2d_hwcn") @tvm.tag_scope(tag="conv2d_hwcn")
......
...@@ -2,12 +2,32 @@ ...@@ -2,12 +2,32 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
def get_const_int(expr):
"""Verifies expr is integer and get the constant value.
Parameters
----------
expr :
The input expression.
Returns
-------
out_tuple : tuple of int
The output.
"""
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
expr = tvm.ir_pass.Simplfy(expr)
if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Expect value to be constant int")
return expr.value
def get_const_tuple(in_tuple): def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int. """Verifies input tuple is IntImm, returns tuple of int.
Parameters Parameters
---------- ----------
in_tuple : tuple of tvm.expr.IntImm in_tuple : tuple of Expr
The input. The input.
Returns Returns
...@@ -17,7 +37,7 @@ def get_const_tuple(in_tuple): ...@@ -17,7 +37,7 @@ def get_const_tuple(in_tuple):
""" """
out_tuple = () out_tuple = ()
for elem in in_tuple: for elem in in_tuple:
if not isinstance(elem, tvm.expr.IntImm): if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
raise ValueError("Element of input tuple should be IntImm") raise ValueError("Element of input tuple should be const int")
out_tuple = out_tuple + (elem.value, ) out_tuple = out_tuple + (elem.value, )
return out_tuple return out_tuple
import tvm import tvm
import topi import topi
from topi import util
def test_util():
x = tvm.const(100)
assert util.get_const_int(x) == 100
assert util.get_const_tuple((x, x)) == (100, 100)
def test_ewise(): def test_ewise():
m = tvm.var('m') m = tvm.var('m')
...@@ -19,4 +27,5 @@ def test_ewise(): ...@@ -19,4 +27,5 @@ def test_ewise():
if __name__ == "__main__": if __name__ == "__main__":
test_util()
test_ewise() test_ewise()
...@@ -3,7 +3,7 @@ import os ...@@ -3,7 +3,7 @@ import os
import numpy as np import numpy as np
import tvm import tvm
import topi import topi
from topi.nn.util import get_const_tuple from topi.util import get_const_tuple
def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding): def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding):
......
...@@ -2,7 +2,7 @@ import tvm ...@@ -2,7 +2,7 @@ import tvm
import topi import topi
import numpy as np import numpy as np
from scipy import signal from scipy import signal
from topi.nn.util import get_const_tuple from topi.util import get_const_tuple
from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding): def depthwise_conv2d_map_with_workload(batch, in_channel, in_height, channel_multiplier, filter_height, stride_h, padding):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment