Commit b330d301 by Jon Soifer, committed by masahi

[TOPI][x86] Introduce schedule_injective_from_existing and unify external schedules for all targets (#3983)

* Fix extern schedule for x86

* Register x86::schedule_extern

* Fix

* Fix

* Replace extern.py with extern.h

* Introduce new generic function schedule_injective_from_existing

* Fix

* Fix

* Add back to C++

* Fix style

* Injective schedule calls local schedule_injective_from_existing

* Fix

* Remove target arg from schedule_injective_from_existing

* Fix docs

* Try to fix unit test

* Fix test

* Fix other tests

* Fix bug
parent d21f0ad5
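
The diffs below introduce one generic entry point, schedule_injective_from_existing(sch, out), which updates a schedule that already exists instead of creating one from a list of output tensors. A minimal usage sketch, assuming only the signature shown in the diffs (the shapes and the "llvm" target are illustrative):

import tvm
from topi import generic

A = tvm.placeholder((1024,), name="A")
B = tvm.compute(A.shape, lambda i: A[i] + 1, name="B")

s = tvm.create_schedule(B.op)
with tvm.target.create("llvm"):
    # Dispatches to the "cpu" registration below and returns the updated s.
    s = generic.schedule_injective_from_existing(s, B)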
......@@ -102,7 +102,11 @@ TEST(BuildModule, Heterogeneous) {
return copy[i] - C[i];
}, "elemwise_sub");
const runtime::PackedFunc* enter_target_scope_func = runtime::Registry::Get("_EnterTargetScope");
(*enter_target_scope_func)(target_cuda);
auto s1 = topi::cuda::schedule_injective(target_cuda, {elemwise_add});
(*enter_target_scope_func)(target_llvm);
auto s2 = create_schedule({elemwise_sub->op});
auto config = BuildConfig::Create();
......
......@@ -174,6 +174,7 @@ def test_simplex_data_transferring():
dev_tar = {"cuda": "cuda", "opencl": "opencl"}
for device, target in dev_tar.items():
with tvm.target.create(device):
check_device(device, target)
......@@ -394,6 +395,7 @@ def test_duplex_data_transferring():
dev_tar = {"cuda": "cuda", "opencl": "opencl"}
for device, target in dev_tar.items():
with tvm.target.create(device):
check_device(device, target)
if __name__ == "__main__":
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file cuda/extern.h
* \brief CUDA schedule for extern followed by injective operations
*/
#ifndef TOPI_CUDA_EXTERN_H_
#define TOPI_CUDA_EXTERN_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Schedule a given operation representing one of the outputs of an
* external function which is followed by injective operations.
*
* \param target The target to generate a schedule for.
* \param op The operation representing the output followed by injective operations.
* \param sch The schedule to apply this scheduling to
*
* \return The schedule given by sch
*/
inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
auto x = op.output(0);
auto fused = detail::Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
auto num_thread = target->max_num_threads;
IterVar bx, tx;
sch[x].split(fused, num_thread, &bx, &tx);
sch[x].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
sch[x].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
return sch;
}
/*!
* \brief Schedule an extern op followed by injective operations.
* For example, cudnn kernel + bias add + relu
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the op.
*/
inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
for (auto out : outs) {
if (out->op->derived_from<ExternOpNode>()) {
continue;
}
ScheduleOutputForExtern(target, out->op, s);
}
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_EXTERN_H_
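
For reference, a rough Python sketch of what schedule_extern does for a non-extern output: inline the injective tail, then fuse, split, and bind. The factor 1024 stands in for target->max_num_threads; shapes are illustrative.

import tvm

A = tvm.placeholder((1 << 16,), name="A")
B = tvm.compute(A.shape, lambda i: A[i] * 2, name="B")  # injective tail op

s = tvm.create_schedule(B.op)
tvm.schedule.AutoInlineInjective(s)
fused = s[B].fuse(*s[B].op.axis)
bx, tx = s[B].split(fused, factor=1024)  # stand-in for max_num_threads
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))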
......@@ -33,21 +33,24 @@ namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Schedule a given injective operation.
*
* \param target The target to generate a schedule for.
* \param op The operation representing the injective operation.
* \param s The schedule to apply this scheduling to
*/
inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
auto x = op.output(0);
auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
* \brief Updates an existing schedule for the given injective ops.
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
*
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
auto fused = detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
auto target = Target::Current(false);
auto num_thread = target->max_num_threads;
IterVar bx, tx;
s[x].split(fused, num_thread, &bx, &tx);
s[x].bind(bx, thread_axis(Range(), "blockIdx.x"));
s[x].bind(tx, thread_axis(Range(), "threadIdx.x"));
sch[out].split(fused, num_thread, &bx, &tx);
sch[out].bind(bx, thread_axis(Range(), "blockIdx.x"));
sch[out].bind(tx, thread_axis(Range(), "threadIdx.x"));
return sch;
}
/*!
......@@ -66,7 +69,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
for (auto out : outs) {
ScheduleInjectiveOp(target, out->op, s);
schedule_injective_from_existing(s, out);
}
return s;
}
......
......@@ -28,6 +28,7 @@
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
#include "injective.h"
namespace topi {
using namespace tvm;
......@@ -47,6 +48,15 @@ inline Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
for (auto out : outs) {
if (out->op->derived_from<ExternOpNode>()) {
continue;
}
tvm::GenericFunc::Get("schedule_injective_from_existing")(s, out);
}
return s;
}
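
Since the extern schedule now reaches target-specific code through the "schedule_injective_from_existing" GenericFunc, supporting a new target only requires one registration. A hedged sketch using the same decorator pattern that appears in the Python files below ("my_target" is a hypothetical target key, used purely for illustration):

from topi import generic

@generic.schedule_injective_from_existing.register(["my_target"])
def schedule_injective_from_existing(sch, out):
    # Minimal behaviour mirroring the generic default: fuse the output axes.
    sch[out].fuse(*sch[out].op.axis)
    return sch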
......
......@@ -35,6 +35,19 @@ using namespace tvm;
namespace generic {
/*!
* \brief Updates an existing schedule for the given injective ops.
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
*
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
detail::Fuse(sch[out], sch[out]->op.as<ComputeOpNode>()->axis);
return sch;
}
/*!
* \brief Create a generic schedule for the given injective ops.
*
* \param target The target to generate a schedule for.
......@@ -50,7 +63,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
auto x = outs[0];
detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
schedule_injective_from_existing(s, x);
return s;
}
......
......@@ -33,6 +33,28 @@ namespace topi {
using namespace tvm;
namespace x86 {
/*!
* \brief Updates an existing schedule for the given injective ops.
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
*
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
auto axis = sch[out]->op.as<ComputeOpNode>()->axis;
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
auto fused = detail::Fuse(sch[out], { n, c }); // fuse the two outermost axes (n and h for NHWC)
sch[out].parallel(fused);
} else {
sch[out].parallel(axis[0]);
}
return sch;
}
/*!
* \brief Create an x86 schedule for the given injective ops.
*
......@@ -50,15 +72,7 @@ inline Schedule schedule_injective(const Target &target, const Array<Tensor>& ou
tvm::schedule::AutoInlineInjective(s);
auto x = outs[0];
auto axis = s[x]->op.as<ComputeOpNode>()->axis;
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
auto fused = detail::Fuse(s[x], { n, c }); // for nhwc layout, fuse n and h
s[x].parallel(fused);
} else {
s[x].parallel(axis[0]);
}
schedule_injective_from_existing(s, x);
return s;
}
......
......@@ -19,6 +19,32 @@
import tvm
from .. import generic
@generic.schedule_injective_from_existing.register(["arm_cpu"])
def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.
Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.
Returns
-------
sch: Schedule
The updated schedule.
"""
if len(sch[out].op.axis) >= 4:
fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
sch[out].parallel(fused)
elif len(sch[out].op.axis) >= 3:
fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
sch[out].parallel(fused)
elif len(sch[out].op.axis) >= 2:
sch[out].parallel(sch[out].op.axis[0])
return sch
@generic.schedule_injective.register(["arm_cpu"])
def schedule_injective(outs):
"""ARM CPU schedule for injective op.
......@@ -42,14 +68,7 @@ def schedule_injective(outs):
(io, ii) = s[x].split(list(s[x].op.axis)[-1], 8)
s[x].vectorize(ii)
tvm.schedule.AutoInlineInjective(s)
if len(s[x].op.axis) >= 4:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 3:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 2:
s[x].parallel(s[x].op.axis[0])
schedule_injective_from_existing(s, x)
return s
@generic.schedule_concatenate.register(["arm_cpu"])
......
......@@ -13,7 +13,6 @@ from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .dense import schedule_dense
from .pooling import schedule_pool, schedule_adaptive_pool
from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize
from .batch_matmul import schedule_batch_matmul
from .vision import *
......
......@@ -19,7 +19,7 @@
import tvm
from tvm import autotvm
from .injective import _schedule_injective
from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
......@@ -172,8 +172,8 @@ def schedule_conv2d_NCHWc_int8(cfg, s, output):
if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
packed_kernel.name == 'packed_kernel':
# data and kernel are not pre-computed, schedule layout transform here
_schedule_injective(packed_data.op, s)
_schedule_injective(packed_kernel.op, s)
schedule_injective_from_existing(s, packed_data)
schedule_injective_from_existing(s, packed_kernel)
if pad_data != packed_data:
s[pad_data].compute_inline()
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable,
"""Schedule for cudnn and miopen extern op"""
import tvm
from .. import generic
from .injective import _schedule_injective
@generic.schedule_extern.register(["cuda", "gpu"])
def schedule_extern(outs):
"""Schedule for an extern op followed by injective operations.
For example, cudnn kernel + bias add + relu.
Parameters
----------
outs: Array of Tensor
The computation graph description of extern plus injective ops in the format
of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
for out in outs:
if isinstance(out.op, tvm.tensor.ExternOp):
continue
_schedule_injective(out.op, s)
return s
......@@ -19,7 +19,7 @@
import tvm
from tvm import autotvm
from .injective import _schedule_injective
from .injective import schedule_injective_from_existing
from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.util import get_pad_tuple
......@@ -201,8 +201,8 @@ def schedule_group_conv2d_NCHWc_int8(cfg, s, output):
if isinstance(packed_kernel.op, tvm.tensor.ComputeOp) and\
packed_kernel.name == 'packed_kernel':
# data and kernel are not pre-computed, schedule layout transform here
_schedule_injective(packed_data.op, s)
_schedule_injective(packed_kernel.op, s)
schedule_injective_from_existing(s, packed_data)
schedule_injective_from_existing(s, packed_kernel)
if pad_data != packed_data:
s[pad_data].compute_inline()
......
......@@ -19,33 +19,46 @@
import tvm
from .. import generic, util
def _schedule_injective(op, sch):
x = op.output(0)
fused = sch[x].fuse(*sch[x].op.axis)
@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.
Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.
Returns
-------
sch: Schedule
The updated schedule.
"""
fused = sch[out].fuse(*sch[out].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
max_block = 256
try:
const_size = util.get_const_int(util.prod(x.shape))
const_size = util.get_const_int(util.prod(out.shape))
max_block = 256
need_block_split = const_size > max_block * num_thread
except ValueError:
need_block_split = False
if need_block_split:
xo, xi = sch[x].split(fused, factor=num_thread * max_block)
bx, tx = sch[x].split(xi, factor=num_thread)
sch[x].reorder(bx, tx, xo)
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
xo, xi = sch[out].split(fused, factor=num_thread * max_block)
bx, tx = sch[out].split(xi, factor=num_thread)
sch[out].reorder(bx, tx, xo)
sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
else:
bx, tx = sch[x].split(fused, factor=num_thread)
sch[x].bind(tx, tvm.thread_axis("threadIdx.x"))
sch[x].bind(bx, tvm.thread_axis("blockIdx.x"))
bx, tx = sch[out].split(fused, factor=num_thread)
sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
return sch
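
A small numeric sketch of the split decision above (1024 is a typical CUDA max_num_threads; the values are illustrative):

num_thread, max_block = 1024, 256
const_size = 1 << 20                        # e.g. a 1M-element tensor
print(const_size > max_block * num_thread)  # True -> take the (xo, bx, tx) path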
@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs):
"""Schedule for injective op.
......@@ -66,7 +79,7 @@ def schedule_injective(outs):
tvm.schedule.AutoInlineInjective(s)
for out in outs:
_schedule_injective(out.op, s)
schedule_injective_from_existing(s, out)
return s
schedule_elemwise = schedule_injective
......
......@@ -20,7 +20,7 @@ from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
from .injective import _schedule_injective
from .injective import schedule_injective_from_existing
def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce:
......@@ -30,7 +30,7 @@ def _schedule_reduce(op, sch, is_idx_reduce=False):
data_out = op.output(0)
if not sch[data_out].op.reduce_axis:
return _schedule_injective(op, sch)
return schedule_injective_from_existing(sch, op.output(0))
if len(sch[data_out].op.axis) > 0:
all_reduce = False
......@@ -126,7 +126,7 @@ def schedule_reduce(outs):
"""Internal travserse function"""
if tag.is_broadcast(operator.tag):
if operator not in scheduled_ops:
_schedule_injective(operator, sch)
schedule_injective_from_existing(sch, operator.output(0))
for tensor in operator.input_tensors:
traverse_after_reduce(tensor.op)
elif operator.tag == 'comm_reduce':
......
......@@ -18,7 +18,7 @@
"""Schedule for softmax operator"""
import tvm
from .. import generic
from .injective import _schedule_injective
from .injective import schedule_injective_from_existing
@generic.schedule_softmax.register(["cuda", "gpu"])
def schedule_softmax(outs):
......@@ -58,7 +58,7 @@ def schedule_softmax(outs):
ops.append(exp.op)
for op in ops:
s = _schedule_injective(op, s)
s = schedule_injective_from_existing(s, op.output(0))
else:
num_thread = 64
block_x = tvm.thread_axis("blockIdx.x")
......
......@@ -42,10 +42,10 @@ def _schedule_sort(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
from .injective import schedule_injective_from_existing
def traverse(op):
if tag.is_injective(op.tag):
_schedule_injective(op, s)
schedule_injective_from_existing(s, op.output(0))
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
......
......@@ -28,10 +28,10 @@ def _default_schedule(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
from .injective import schedule_injective_from_existing
def traverse(op):
if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
_schedule_injective(op, s)
schedule_injective_from_existing(s, op.output(0))
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
......
......@@ -19,6 +19,7 @@
from __future__ import absolute_import as _abs
import tvm
from .. import cpp
@tvm.target.generic_func
def schedule_extern(outs):
......@@ -35,8 +36,5 @@ def schedule_extern(outs):
sch: Schedule
The computation schedule for the op.
"""
target = tvm.target.current_target(allow_none=False)
if target.target_name != "llvm":
raise RuntimeError("schedule_extern not registered for '%s'" % target)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
return tvm.create_schedule([x.op for x in outs])
target = tvm.target.current_target()
return cpp.generic.schedule_extern(target, outs)
......@@ -20,6 +20,25 @@ from __future__ import absolute_import as _abs
import tvm
@tvm.target.override_native_generic_func("schedule_injective_from_existing")
def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.
Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.
Returns
-------
sch: Schedule
The updated schedule.
"""
sch[out].fuse(*sch[out].op.axis)  # fuse all output axes (generic default; no binding or parallelism)
return sch
@tvm.target.override_native_generic_func("schedule_injective")
def schedule_injective(outs):
"""Schedule for injective op.
......@@ -42,7 +61,7 @@ def schedule_injective(outs):
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
schedule_injective_from_existing(s, x)
return s
@tvm.target.generic_func
......
......@@ -19,6 +19,27 @@
import tvm
from .. import generic
@generic.schedule_injective_from_existing.register(["hls"])
def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.
Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.
Returns
-------
sch: Schedule
The updated schedule.
"""
fused = sch[out].fuse(*sch[out].op.axis)
px, x = sch[out].split(fused, nparts=1)
sch[out].bind(px, tvm.thread_axis("pipeline"))
return sch
@generic.schedule_injective.register(["hls"])
def schedule_injective(outs):
"""Schedule for injective op.
......@@ -38,9 +59,7 @@ def schedule_injective(outs):
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
for out in outs:
fused = s[out].fuse(*s[out].op.axis)
px, x = s[out].split(fused, nparts=1)
s[out].bind(px, tvm.thread_axis("pipeline"))
schedule_injective_from_existing(s, out)
return s
schedule_elemwise = schedule_injective
......
......@@ -19,11 +19,24 @@
import tvm
from .. import generic
def _schedule_injective(op, sch):
x = op.output(0)
sch[x].opengl()
return sch
@generic.schedule_injective_from_existing.register(["opengl"])
def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.
Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.
Returns
-------
sch: Schedule
The updated schedule.
"""
sch[out].opengl()
return sch
@generic.schedule_injective.register(["opengl"])
def schedule_injective(outs):
......@@ -45,7 +58,7 @@ def schedule_injective(outs):
tvm.schedule.AutoInlineInjective(s)
for out in outs:
_schedule_injective(out.op, s)
schedule_injective_from_existing(s, out)
return s
schedule_elemwise = schedule_injective
......
......@@ -111,6 +111,10 @@ def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
def _schedule_dense(cfg, outs):
target = tvm.target.current_target()
if "cblas" in target.libs:
return generic.schedule_extern(outs)
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
......
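
The cblas fallthrough in the _schedule_dense hunk above is driven by the target's libs; a quick hedged sketch of how that flag appears (the target string is illustrative):

import tvm

# "llvm -libs=cblas" makes dense lower to an extern cblas call, so the
# schedule falls through to the unified generic.schedule_extern.
with tvm.target.create("llvm -libs=cblas"):
    target = tvm.target.current_target()
    print("cblas" in target.libs)  # True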
......@@ -20,6 +20,32 @@ from __future__ import absolute_import as _abs
import tvm
from .. import generic
@generic.schedule_injective_from_existing.register(["cpu"])
def schedule_injective_from_existing(sch, out):
"""Schedule for injective op from existing schedule.
Parameters
----------
sch: Schedule
The schedule to update.
out: Tensor
The tensor representing the injective op.
Returns
-------
sch: Schedule
The updated schedule.
"""
if len(sch[out].op.axis) >= 5:
fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
sch[out].parallel(fused)
elif len(sch[out].op.axis) >= 3:
fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1])
sch[out].parallel(fused)
elif len(sch[out].op.axis) >= 1:
sch[out].parallel(sch[out].op.axis[0])
return sch
@generic.schedule_injective.register(["cpu"])
def schedule_injective(outs):
"""X86 schedule for injective op.
......@@ -39,14 +65,7 @@ def schedule_injective(outs):
x = outs[0]
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
if len(s[x].op.axis) >= 5:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 3:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 1:
s[x].parallel(s[x].op.axis[0])
schedule_injective_from_existing(s, x)
return s
@generic.schedule_concatenate.register(["cpu"])
......
......@@ -56,7 +56,6 @@
#include <topi/generic/injective.h>
#include <topi/cuda/dense.h>
#include <topi/cuda/extern.h>
#include <topi/cuda/injective.h>
#include <topi/cuda/pooling.h>
#include <topi/cuda/reduction.h>
......@@ -586,6 +585,11 @@ TVM_REGISTER_GLOBAL("topi.generic.schedule_injective")
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
});
/* x86 schedules */
TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -611,6 +615,11 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
*rv = topi::x86::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
});
/* ROCm schedules */
TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......@@ -643,14 +652,14 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
*rv = topi::cuda::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_extern")
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_extern(args[0], args[1]);
*rv = topi::cuda::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective")
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_injective(args[0], args[1]);
*rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool")
......@@ -752,6 +761,30 @@ TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));
/*! \brief Builder function for instantiating schedules from existing schedules. */
using FTVMScheduleFromExistingBuilder = std::function<
tvm::Schedule(tvm::Schedule sch, const tvm::Tensor& out)>;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
*ret = builder(args[0], args[1]);
});
}
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(topi::cuda::schedule_injective_from_existing));
/*! \brief Builder function for instantiating dense ops. */
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
const tvm::Tensor& data,
......
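
Taken together, the registrations above let one call site serve every target. A hedged closing sketch: under a cuda target the same generic call binds blockIdx.x/threadIdx.x via the CUDA variant registered above (no GPU is needed just to build the schedule; shapes are illustrative):

import tvm
from topi import generic

A = tvm.placeholder((256, 256), name="A")
B = tvm.compute(A.shape, lambda i, j: A[i, j] * 2, name="B")

s = tvm.create_schedule(B.op)
with tvm.target.create("cuda"):
    s = generic.schedule_injective_from_existing(s, B)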