Commit 6292204e by alex-weaver, committed by Tianqi Chen

Implement C++ registry to back Python target.generic_func (#892)

parent 6588662f
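This commit moves the dispatch table behind Python's `tvm.target.generic_func` into a C++ registry, so native C++ registrations and Python overrides share one table. A minimal sketch of the registration pattern follows, assuming only what the diff below shows (`TVM_REGISTER_GENERIC_FUNC`, `set_default`, `register_func`); `demo_func`, the integer results, and the `tvm/build_module.h` header location are illustrative assumptions, not the commit's code:

```cpp
// Minimal sketch of the new C++ generic-function registry (illustrative names).
// The default runs when no key of the current target matches; the override
// runs when the current target carries the "cuda" or "gpu" key.
#include <tvm/build_module.h>          // assumed home of TVM_REGISTER_GENERIC_FUNC
#include <tvm/runtime/packed_func.h>

using tvm::runtime::PackedFunc;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;

TVM_REGISTER_GENERIC_FUNC(demo_func)
.set_default(PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = 1; }))
.register_func({ "cuda", "gpu" },
               PackedFunc([](TVMArgs args, TVMRetValue* rv) { *rv = 2; }));
```

The Python tests below exercise exactly this dispatch by entering `tvm.target.create(...)` contexts.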
......@@ -6,6 +6,7 @@
#include <tvm/runtime/threading_backend.h>
#include <dmlc/logging.h>
#include <thread>
#include <algorithm>
#if defined(__linux__)
#include <sched.h>
#endif
......
......@@ -31,7 +31,7 @@ TEST(BuildModule, Basic) {
  auto target = target::llvm();
  auto lowered = lower(s, args, "func", binds, config);
  auto module = build(lowered, target, nullptr, config);
  auto module = build(lowered, target, Target(), config);
}
......
......@@ -34,11 +34,16 @@ def test_target_dispatch():
with tvm.target.create("metal"):
assert mygeneric(1) == 3
try:
mygeneric(0)
raise RuntimeError("not reached")
except RuntimeError:
pass
assert tvm.target.current_target() == None

def test_target_string_parse():
    target = tvm.target.create("cuda -libs=cublas,cudnn")
    assert target.target_name == "cuda"
    assert target.options == ['-libs=cublas,cudnn']
    assert target.keys == ['cuda', 'gpu']
    assert target.libs == ['cublas', 'cudnn']

if __name__ == "__main__":
    test_target_dispatch()
    test_target_string_parse()
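For comparison, the same parse can be exercised from C++, where this commit replaces `target.libs` field access with the `target->libs()` accessor. A hedged sketch, using only accessors visible in this diff plus the `tvm::Target::create` factory; the variable name is illustrative:

```cpp
// Sketch mirroring test_target_string_parse above, via the -> accessors
// this diff introduces. Assumes tvm::Target::create parses option strings.
tvm::Target target = tvm::Target::create("cuda -libs=cublas,cudnn");
CHECK_EQ(target->target_name, "cuda");
CHECK(target->libs().count("cublas"));
CHECK(target->libs().count("cudnn"));
```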
......@@ -24,31 +24,30 @@ namespace cuda {
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_cuda(const Target& target,
                              const tvm::Tensor& data,
                              const tvm::Tensor& weight,
                              tvm::Tensor* bias) {
                              const tvm::Tensor& bias) {
  CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
  CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
  if (bias != nullptr) {
    CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
  if (bias.defined()) {
    CHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
  }
  auto batch = data->shape[0];
  auto in_dim = data->shape[1];
  auto out_dim = weight->shape[0];
  if (target.libs.count("cublas") > 0) {
  if (target->libs().count("cublas")) {
    auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
    if (bias != nullptr) {
      auto bias_val = *bias;
    if (bias.defined()) {
      mm = tvm::compute({ batch, out_dim },
                        [&](Var i, Var j) {
                          return mm(i, j) + bias_val(j);
                          return mm(i, j) + bias(j);
                        }, "tensor", kBroadcast);
    }
......@@ -67,8 +66,8 @@ inline tvm::Tensor dense_cuda(const Target& target,
* \return A schedule for the given ops.
*/
inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
  if (target.target_name == "cuda" &&
      target.libs.count("cublas") > 0) {
  if (target->target_name == "cuda" &&
      target->libs().count("cublas")) {
    return topi::generic::schedule_extern(target, outs);
  }
......
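With the signature change from `tvm::Tensor*` to `const tvm::Tensor&`, an undefined `Tensor()` now stands in for the old null pointer, as the updated doc comment notes. A hedged usage sketch; the wrapper function and tensor names below are hypothetical:

```cpp
// Sketch: with the reference-based signature, passing an undefined Tensor
// means "no bias". All names here are illustrative, not from the commit.
#include <topi/cuda/dense.h>

tvm::Tensor DenseWithoutBias(const tvm::Target& target,
                             const tvm::Tensor& data,     // [batch, in_dim]
                             const tvm::Tensor& weight) { // [out_dim, in_dim]
  return topi::cuda::dense_cuda(target, data, weight, tvm::Tensor());
}
```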
......@@ -28,7 +28,7 @@ namespace cuda {
inline Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
  auto x = op.output(0);
  auto fused = detail::Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
  auto num_thread = target.max_num_threads;
  auto num_thread = target->max_num_threads;
  IterVar bx, tx;
  sch[x].split(fused, num_thread, &bx, &tx);
  sch[x].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
......
......@@ -25,7 +25,7 @@ namespace cuda {
inline void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
  auto x = op.output(0);
  auto fused = detail::Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
  auto num_thread = target.max_num_threads;
  auto num_thread = target->max_num_threads;
  IterVar bx, tx;
  s[x].split(fused, num_thread, &bx, &tx);
  s[x].bind(bx, thread_axis(Range(), "blockIdx.x"));
......
......@@ -34,7 +34,7 @@ inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
  auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
    s[padded_input].compute_inline();
    auto num_thread = target.max_num_threads;
    auto num_thread = target->max_num_threads;
    Tensor out;
    Tensor OL;
    if (detail::contains(s->outputs, pool->op)) {
......
......@@ -51,7 +51,7 @@ Schedule ScheduleReduce(const Target& target,
  if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
    all_reduce = false;
    num_thread = 32;
    if (target.target_name == "opencl") {
    if (target->target_name == "opencl") {
      // Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
      // Don't know why.
      num_thread = 16;
......@@ -61,7 +61,7 @@ Schedule ScheduleReduce(const Target& target,
    thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
  } else {
    all_reduce = true;
    num_thread = target.max_num_threads;
    num_thread = target->max_num_threads;
    thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
  }
......
......@@ -20,17 +20,17 @@ using namespace tvm;
*
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense(const tvm::Tensor& data,
                         const tvm::Tensor& weight,
                         tvm::Tensor* bias) {
                         const tvm::Tensor& bias) {
  CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
  CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
  if (bias != nullptr) {
    CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
  if (bias.defined()) {
    CHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
  }
  auto batch = data->shape[0];
......@@ -44,12 +44,11 @@ inline tvm::Tensor dense(const tvm::Tensor& data,
      return tvm::sum(data(i, k) * weight(j, k), { k });
    }, "tensor", "dense");
  if (bias != nullptr) {
    auto bias_val = *bias;
  if (bias.defined()) {
    matmul = tvm::compute(
      { batch, out_dim },
      [&](Var i, Var j) {
        return matmul(i, j) + bias_val(j);
        return matmul(i, j) + bias(j);
      }, "tensor", kBroadcast);
  }
......
......@@ -25,31 +25,30 @@ namespace rocm {
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
* \param bias Tensor with shape [out_dim]. Optional; to omit bias, pass Tensor()
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_rocm(const Target& target,
                              const tvm::Tensor& data,
                              const tvm::Tensor& weight,
                              tvm::Tensor* bias) {
                              const tvm::Tensor& bias) {
  CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
  CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
  if (bias != nullptr) {
    CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
  if (bias.defined()) {
    CHECK_EQ(bias->shape.size(), 1) << "dense requires 1-D bias";
  }
  auto batch = data->shape[0];
  auto in_dim = data->shape[1];
  auto out_dim = weight->shape[0];
  if (target.libs.count("rocblas") > 0) {
  if (target->libs().count("rocblas")) {
    auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
    if (bias != nullptr) {
      auto bias_val = *bias;
    if (bias.defined()) {
      mm = tvm::compute({ batch, out_dim },
                        [&](Var i, Var j) {
                          return mm(i, j) + bias_val(j);
                          return mm(i, j) + bias(j);
                        }, "tensor", kBroadcast);
    }
......@@ -68,8 +67,8 @@ inline tvm::Tensor dense_rocm(const Target& target,
* \return A schedule for the given ops.
*/
inline Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
  if (target.target_name == "rocm" &&
      target.libs.count("rocblas") > 0) {
  if (target->target_name == "rocm" &&
      target->libs().count("rocblas")) {
    return topi::generic::schedule_extern(target, outs);
  }
......
......@@ -11,6 +11,10 @@ from __future__ import absolute_import as _abs
from tvm._ffi.libinfo import __version__

# Ensure C++ schedules get registered first, so python schedules can
# override them.
from . import cpp
from .math import *
from .tensor import *
from .reduction import *
......@@ -24,7 +28,6 @@ from . import mali
from . import opengl
from . import util
from . import rocm
from . import cpp
from . import vision
# not import testing by default
# because testing can have extra deps that are not necessary
......
......@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import tvm

@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_injective")
def schedule_injective(outs):
    """Schedule for injective op.
......
......@@ -106,7 +106,7 @@ def schedule_depthwise_conv2d_nhwc(outs):
    return _default_schedule(outs, False)

@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_reduce")
def schedule_reduce(outs):
    """Schedule for reduction
......@@ -124,7 +124,7 @@ def schedule_reduce(outs):
    return _default_schedule(outs, True)

@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_softmax")
def schedule_softmax(outs):
    """Schedule for softmax
......@@ -142,7 +142,7 @@ def schedule_softmax(outs):
    return _default_schedule(outs, False)

@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_dense")
def schedule_dense(outs):
    """Schedule for dense
......@@ -160,7 +160,7 @@ def schedule_dense(outs):
    return _default_schedule(outs, False)

@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_pool")
def schedule_pool(outs):
    """Schedule for pool
......@@ -178,7 +178,7 @@ def schedule_pool(outs):
    return _default_schedule(outs, False)

@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_global_pool")
def schedule_global_pool(outs):
    """Schedule for global pool
......@@ -195,7 +195,7 @@ def schedule_global_pool(outs):
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_binarize_pack")
def schedule_binarize_pack(outs):
"""Schedule for binarize_pack
......@@ -213,7 +213,7 @@ def schedule_binarize_pack(outs):
    return _default_schedule(outs, False)

@tvm.target.generic_func
@tvm.target.override_native_generic_func("schedule_binary_dense")
def schedule_binary_dense(outs):
    """Schedule for binary_dense
......
......@@ -39,7 +39,7 @@ def dense_default(data, weight, bias=None):
    return matmul

@tvm.target.generic_func
@tvm.target.override_native_generic_func("dense")
def dense(data, weight, bias=None):
    """Applies a linear transformation: :math:`Y = XW^T + b`.
......
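Each `override_native_generic_func("name")` above rebinds the named entry in the C++ registry populated by the `TVM_REGISTER_GENERIC_FUNC` calls later in this diff, so Python registrations win over the native defaults. A hedged sketch of the C++ side of that lookup; `GenericFunc::Get` and its call operator are assumptions about the new registry API, and `outs` is a hypothetical `Array<Tensor>`:

```cpp
// Sketch: a C++ caller resolving the same named generic function the Python
// decorator binds to. Dispatch keys off Target::current_target().
tvm::GenericFunc fsched = tvm::GenericFunc::Get("schedule_dense");
tvm::Schedule s = fsched(outs);  // Python override if registered, else default
```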
......@@ -51,6 +51,7 @@ struct extension_class_info<tvm::Target> {
} // namespace runtime
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
......@@ -281,15 +282,7 @@ TVM_REGISTER_GLOBAL("topi.nn.binary_dense")
/* Ops from nn/dense.h */
TVM_REGISTER_GLOBAL("topi.nn.dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  Tensor bias_val;
  Tensor *bias;
  if (args[2].type_code() == kNull) {
    bias = nullptr;
  } else {
    bias_val = args[2];
    bias = &bias_val;
  }
  *rv = nn::dense(args[0], args[1], bias);
  *rv = nn::dense(args[0], args[1], args[2]);
});
/* Ops from nn/dilate.h */
......@@ -388,15 +381,7 @@ TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
/* ROCm schedules */
TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  Tensor bias_val;
  Tensor *bias;
  if (args[3].type_code() == kNull) {
    bias = nullptr;
  } else {
    bias_val = args[3];
    bias = &bias_val;
  }
  *rv = rocm::dense_rocm(args[0], args[1], args[2], bias);
  *rv = rocm::dense_rocm(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
......@@ -407,15 +392,7 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
  Tensor bias_val;
  Tensor *bias;
  if (args[3].type_code() == kNull) {
    bias = nullptr;
  } else {
    bias_val = args[3];
    bias = &bias_val;
  }
  *rv = cuda::dense_cuda(args[0], args[1], args[2], bias);
  *rv = cuda::dense_cuda(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
......@@ -453,4 +430,106 @@ TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
  *rv = topi::cuda::schedule_softmax(args[0], args[1]);
});

/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
  tvm::Schedule(const tvm::Target& target, const tvm::Array<tvm::Tensor>& outs)>;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
  return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
    auto target = Target::current_target(false);
    Array<Tensor> outs;
    // Accept either an Array<Tensor> or a single Tensor as the first argument.
    NodeRef argNodeRef = args[0];
    if (argNodeRef->type_index() == outs->type_index()) {
      outs = args[0];
    } else {
      outs = Array<Tensor> { args[0] };
    }
    *ret = builder(target, outs);
  });
}

TVM_REGISTER_GENERIC_FUNC(schedule_injective)
.set_default(WrapSchedule(topi::generic::schedule_injective))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective));

TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax));

TVM_REGISTER_GENERIC_FUNC(schedule_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense))
.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense));

TVM_REGISTER_GENERIC_FUNC(schedule_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool));

TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce));

TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack));

TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));

/*! \brief Builder function for instantiating dense ops. */
using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
                                                     const tvm::Tensor& data,
                                                     const tvm::Tensor& weight,
                                                     const tvm::Tensor& bias)>;
/*!
* \brief Helper function for registering dense ops matching the
* FTVMDenseOpBuilder signature. The op builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The op builder to wrap.
*
* \return The wrapped op builder
*/
inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
  return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
    auto target = Target::current_target(false);
    Tensor data = args[0];
    Tensor weight = args[1];
    Tensor bias = args[2];
    *ret = builder(target, data, weight, bias);
  });
}

TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
                            const tvm::Tensor& data,
                            const tvm::Tensor& weight,
                            const tvm::Tensor& bias) {
  return topi::nn::dense(data, weight, bias);
}))
.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda))
.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm));
} // namespace topi