Commit 31ba0139 by masahi Committed by Wuwei Lin

[ROCm] Fix dense autotvm template registration (#3136)

* Fix rocm dense autotvm template

* suppres lint warning
parent 094fc680
......@@ -11,6 +11,7 @@ from .group_conv2d_nchw import schedule_conv2d_nchw_cuda
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
......
......@@ -14,18 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for dense operator"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from tvm.contrib import rocblas
import topi
from ..nn.dense import dense, dense_default
from .. import tag
from .. import generic
@dense.register("rocm")
def dense_rocm(data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_compute(dense, "rocm", "direct")
def dense_rocm(cfg, data, weight, bias=None, out_dtype=None):
"""Dense operator for rocm backend.
Parameters
......@@ -67,8 +68,8 @@ def dense_rocm(data, weight, bias=None, out_dtype=None):
return dense_default(data, weight, bias, out_dtype)
@generic.schedule_dense.register(["rocm"])
def schedule_dense(outs):
@autotvm.register_topi_schedule(generic.schedule_dense, "rocm", "direct")
def schedule_dense(cfg, outs):
"""Schedule for dense operator.
Parameters
......@@ -85,4 +86,4 @@ def schedule_dense(outs):
target = tvm.target.current_target()
if target.target_name == "rocm" and "rocblas" in target.libs:
return generic.schedule_extern(outs)
return topi.cuda.schedule_dense(outs)
return topi.cuda.schedule_dense(cfg, outs)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment