Commit 31ba0139 by masahi Committed by Wuwei Lin

[ROCm] Fix dense autotvm template registration (#3136)

* Fix rocm dense autotvm template

* suppres lint warning
parent 094fc680
...@@ -11,6 +11,7 @@ from .group_conv2d_nchw import schedule_conv2d_nchw_cuda ...@@ -11,6 +11,7 @@ from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
from .reduction import schedule_reduce from .reduction import schedule_reduce
from .softmax import schedule_softmax from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize from .nn import schedule_lrn, schedule_l2_normalize
......
...@@ -14,18 +14,19 @@ ...@@ -14,18 +14,19 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=invalid-name, unused-variable # pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for dense operator""" """Schedule for dense operator"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import autotvm
from tvm.contrib import rocblas from tvm.contrib import rocblas
import topi import topi
from ..nn.dense import dense, dense_default from ..nn.dense import dense, dense_default
from .. import tag from .. import tag
from .. import generic from .. import generic
@dense.register("rocm") @autotvm.register_topi_compute(dense, "rocm", "direct")
def dense_rocm(data, weight, bias=None, out_dtype=None): def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for rocm backend. """Dense operator for rocm backend.
Parameters Parameters
...@@ -67,8 +68,8 @@ def dense_rocm(data, weight, bias=None, out_dtype=None): ...@@ -67,8 +68,8 @@ def dense_rocm(data, weight, bias=None, out_dtype=None):
return dense_default(data, weight, bias, out_dtype) return dense_default(data, weight, bias, out_dtype)
@generic.schedule_dense.register(["rocm"]) @autotvm.register_topi_schedule(generic.schedule_dense, "rocm", "direct")
def schedule_dense(outs): def schedule_dense(cfg, outs):
"""Schedule for dense operator. """Schedule for dense operator.
Parameters Parameters
...@@ -85,4 +86,4 @@ def schedule_dense(outs): ...@@ -85,4 +86,4 @@ def schedule_dense(outs):
target = tvm.target.current_target() target = tvm.target.current_target()
if target.target_name == "rocm" and "rocblas" in target.libs: if target.target_name == "rocm" and "rocblas" in target.libs:
return generic.schedule_extern(outs) return generic.schedule_extern(outs)
return topi.cuda.schedule_dense(outs) return topi.cuda.schedule_dense(cfg, outs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment