Commit 3f7cce3b by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix schedule for big array (#1340)

parent 2f77a127
...@@ -2,15 +2,7 @@ ...@@ -2,15 +2,7 @@
"""Schedule for cudnn and miopen extern op""" """Schedule for cudnn and miopen extern op"""
import tvm import tvm
from .. import generic from .. import generic
from .injective import _schedule_injective
def _schedule_output(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return sch
@generic.schedule_extern.register(["cuda", "gpu"]) @generic.schedule_extern.register(["cuda", "gpu"])
...@@ -36,5 +28,5 @@ def schedule_extern(outs): ...@@ -36,5 +28,5 @@ def schedule_extern(outs):
for out in outs: for out in outs:
if isinstance(out.op, tvm.tensor.ExternOp): if isinstance(out.op, tvm.tensor.ExternOp):
continue continue
_schedule_output(out.op, s) _schedule_injective(out.op, s)
return s return s
# pylint: disable=invalid-name, unused-variable, # pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator""" """Schedule for composition of injective operator"""
import tvm import tvm
from .. import generic from .. import generic, util
def _schedule_injective(op, sch): def _schedule_injective(op, sch):
x = op.output(0) x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis) fused = sch[x].fuse(*sch[x].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread) max_block = 256
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x")) try:
const_size = util.get_const_int(util.prod(x.shape))
max_block = 256
need_block_split = const_size > max_block * num_thread
except ValueError:
need_block_split = False
if need_block_split:
xo, xi = sch[x].split(fused, factor=num_thread * max_block)
bx, tx = sch[x].split(xi, factor=num_thread)
sch[x].reorder(bx, tx, xo)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
else:
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
return sch return sch
......
...@@ -2,6 +2,28 @@ ...@@ -2,6 +2,28 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
def prod(x):
    """Get the product of all items in the tuple.

    Parameters
    ----------
    x: tuple
        Input tuple

    Returns
    -------
    value : Expr
        The result value. For an empty (falsy) input this is the
        constant ``1`` built with ``tvm.const(1, "int32")``; otherwise it
        is whatever ``*`` yields when the items are multiplied left to
        right.
    """
    # An empty product is the multiplicative identity.
    if not x:
        return tvm.const(1, "int32")

    # Index-based access is kept on purpose: it only requires ``len`` and
    # integer ``__getitem__`` from the container.
    res = x[0]
    for i in range(1, len(x)):
        res = res * x[i]
    return res
def get_const_int(expr): def get_const_int(expr):
"""Verifies expr is integer and get the constant value. """Verifies expr is integer and get the constant value.
......
...@@ -71,6 +71,10 @@ def verify_prelu(x, w): ...@@ -71,6 +71,10 @@ def verify_prelu(x, w):
def test_relu(): def test_relu():
verify_relu(10, 128) verify_relu(10, 128)
def test_schedule_big_array():
    """Run ``verify_relu`` with the large argument pair ``(1024 * 100, 512)``.

    Added together with the big-array schedule fix.
    NOTE(review): the two arguments are presumably the input dimensions;
    confirm against ``verify_relu``, which is defined elsewhere in the file.
    """
    verify_relu(1024 * 100, 512)
def test_leaky_relu(): def test_leaky_relu():
verify_leaky_relu(100, 0.1) verify_leaky_relu(100, 0.1)
...@@ -78,6 +82,7 @@ def test_prelu(): ...@@ -78,6 +82,7 @@ def test_prelu():
verify_prelu((1, 3, 2, 2), (3,)) verify_prelu((1, 3, 2, 2), (3,))
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_big_array()
test_relu() test_relu()
test_leaky_relu() test_leaky_relu()
test_prelu() test_prelu()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment