Unverified Commit 9cb9a51f by Thomas Viehmann Committed by GitHub

rocm: fix dense_rocblas in strategy, topi (#5191)

parent 430cb899
...@@ -129,7 +129,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target): ...@@ -129,7 +129,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported." assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation( strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense_rocblas), wrap_compute_dense(topi.rocm.dense_rocblas),
wrap_topi_schedule(topi.rocm.dense_rocblas), wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
name="dense_rocblas.rocm", name="dense_rocblas.rocm",
plevel=15) plevel=15)
return strategy return strategy
...@@ -123,6 +123,8 @@ def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None): ...@@ -123,6 +123,8 @@ def dense_rocblas(cfg, data, weight, bias=None, out_dtype=None):
output : tvm.te.Tensor output : tvm.te.Tensor
2-D with shape [batch, out_dim] 2-D with shape [batch, out_dim]
""" """
if out_dtype is None:
out_dtype = data.dtype
assert out_dtype == data.dtype, "Mixed precision not supported." assert out_dtype == data.dtype, "Mixed precision not supported."
matmul = rocblas.matmul(data, weight, False, True) matmul = rocblas.matmul(data, weight, False, True)
batch, in_dim = data.shape batch, in_dim = data.shape
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment