Commit 3f7cce3b by Tianqi Chen Committed by GitHub

[SCHEDULE] Fix schedule for big array (#1340)

parent 2f77a127
......@@ -2,15 +2,7 @@
"""Schedule for cudnn and miopen extern op"""
import tvm
from .. import generic
def _schedule_output(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
return sch
from .injective import _schedule_injective
@generic.schedule_extern.register(["cuda", "gpu"])
......@@ -36,5 +28,5 @@ def schedule_extern(outs):
for out in outs:
if isinstance(out.op, tvm.tensor.ExternOp):
continue
_schedule_output(out.op, s)
_schedule_injective(out.op, s)
return s
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic
from .. import generic, util
def _schedule_injective(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = sch[x].split(fused, factor=num_thread)
max_block = 256
try:
const_size = util.get_const_int(util.prod(x.shape))
max_block = 256
need_block_split = const_size > max_block * num_thread
except ValueError:
need_block_split = False
if need_block_split:
xo, xi = sch[x].split(fused, factor=num_thread * max_block)
bx, tx = sch[x].split(xi, factor=num_thread)
sch[x].reorder(bx, tx, xo)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
else:
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
return sch
......
......@@ -2,6 +2,28 @@
from __future__ import absolute_import as _abs
import tvm
def prod(x):
"""Get the product of every items in the tuple.
Parameters
----------
x: tuple
Input tuple
Returns
-------
value : Expr
The result value
"""
if not x:
return tvm.const(1, "int32")
res = x[0]
for i in range(1, len(x)):
res = res * x[i]
return res
def get_const_int(expr):
"""Verifies expr is integer and get the constant value.
......
......@@ -71,6 +71,10 @@ def verify_prelu(x, w):
def test_relu():
verify_relu(10, 128)
def test_schedule_big_array():
verify_relu(1024 * 100 , 512)
def test_leaky_relu():
verify_leaky_relu(100, 0.1)
......@@ -78,6 +82,7 @@ def test_prelu():
verify_prelu((1, 3, 2, 2), (3,))
if __name__ == "__main__":
test_schedule_big_array()
test_relu()
test_leaky_relu()
test_prelu()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment