Unverified commit 00b23049 by Haichen Shen, committed by GitHub

[Topi] Breakdown topi.cc into smaller files (#5253)

* [Topi] Breakdown topi.cc into smaller files

* add missing file
parent 0cc26614
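Each of the new files below registers TOPI kernels as named global PackedFuncs; callers then look the kernels up by name through the runtime registry. A minimal sketch of that lookup (illustrative only, not part of this patch; it assumes `a` and `b` are previously defined `tvm::te::Tensor` placeholders and that `<tvm/runtime/registry.h>` is included):

    const tvm::runtime::PackedFunc* add = tvm::runtime::Registry::Get("topi.add");
    CHECK(add != nullptr) << "topi.add has not been registered";
    // The broadcast registration below dispatches to topi::add(Tensor, Tensor).
    tvm::te::Tensor c = (*add)(a, b);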
@@ -45,9 +45,9 @@ namespace topi {
* \return A Tensor whose op member is a broadcast operation
*/
inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
const tvm::Array<tvm::PrimExpr>& output_shape,
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
const tvm::Array<tvm::PrimExpr>& output_shape,
std::string name = "T_broadcast_to",
std::string tag = kBroadcast) {
CHECK_GE(output_shape.size(), t->shape.size())
<< "Not a broadcast, output dimensionality smaller than input.\noutput: "
<< output_shape << "\nvs\ninput: " << t;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Topi utility function
* \file topi/util.h
*/
#ifndef TOPI_UTIL_H_
#define TOPI_UTIL_H_
#include <tvm/ir/expr.h>
#include <tvm/runtime/packed_func.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
/*! \brief Canonicalize an argument that may be Array<Expr> or int to Array<Expr> */
inline Array<Integer> ArrayOrInt(TVMArgValue arg) {
if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) {
Array<Integer> result;
result.push_back(arg.operator int());
return result;
} else {
return arg;
}
}
inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kTVMObjectHandle &&
static_cast<Object*>(
arg.value().v_handle)->IsInstance<tvm::te::TensorNode>());
}
} // namespace topi
#endif // TOPI_UTIL_H_
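The two helpers above are shared by the per-file registration code that follows. A minimal sketch of how they are typically used inside a registration body (illustrative only; the op name "topi.example_sum" is hypothetical, and it assumes `<topi/reduction.h>` and `<tvm/runtime/registry.h>` are included):

    TVM_REGISTER_GLOBAL("topi.example_sum")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      // args[1] may arrive as a single integer axis or as an Array of axes;
      // ArrayOrInt canonicalizes both forms to Array<Integer>.
      Array<Integer> axes = ArrayOrInt(args[1]);
      *rv = topi::sum(args[0], axes, /*keepdims=*/args[2]);
    });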
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of broadcast operators
* \file broadcast.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <topi/broadcast.h>
#include <topi/util.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
#define TOPI_REGISTER_BCAST_OP(OpName, Op) \
TVM_REGISTER_GLOBAL(OpName) \
.set_body([](TVMArgs args, TVMRetValue *rv) { \
bool lhs_is_tensor = IsTensorType(args[0]); \
bool rhs_is_tensor = IsTensorType(args[1]); \
if (lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::te::Tensor()); \
} else if (!lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
args[1].operator tvm::te::Tensor()); \
} else if (lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::PrimExpr()); \
} else if (!lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
args[1].operator tvm::PrimExpr()); \
} \
});
TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide);
TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod);
TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal);
TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]);
});
} // namespace topi
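The TOPI_REGISTER_BCAST_OP macro above makes each broadcast op accept a tensor or a scalar expression on either side. A minimal sketch of what the dispatched overloads compute (illustrative only; assumes `x` is a `tvm::te::Tensor` and `<topi/broadcast.h>` is included):

    tvm::te::Tensor t1 = topi::add(x, x);                      // Tensor + Tensor
    tvm::te::Tensor t2 = topi::add(x, tvm::PrimExpr(1));       // Tensor + scalar
    tvm::te::Tensor t3 = topi::multiply(tvm::PrimExpr(2), x);  // scalar * Tensor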
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of elemwise operators
* \file elemwise.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <topi/elemwise.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = exp(args[0]);
});
TVM_REGISTER_GLOBAL("topi.fast_exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_exp(args[0]);
});
TVM_REGISTER_GLOBAL("topi.erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = erf(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tan(args[0]);
});
TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.fast_tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sigmoid")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sigmoid(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sqrt(args[0]);
});
TVM_REGISTER_GLOBAL("topi.rsqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rsqrt(args[0]);
});
TVM_REGISTER_GLOBAL("topi.log")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log(args[0]);
});
TVM_REGISTER_GLOBAL("topi.identity")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = identity(args[0]);
});
TVM_REGISTER_GLOBAL("topi.negative")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = negative(args[0]);
});
TVM_REGISTER_GLOBAL("topi.clip")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = clip(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.cast")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cast(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.reinterpret")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = reinterpret(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.elemwise_sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = elemwise_sum(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sign")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sign(args[0]);
});
TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.full_like")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full_like(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.logical_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = logical_not(args[0]);
});
TVM_REGISTER_GLOBAL("topi.bitwise_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = bitwise_not(args[0]);
});
} // namespace topi
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of NN operators
* \file nn.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <topi/nn.h>
#include <topi/nn/bias_add.h>
#include <topi/nn/bnn.h>
#include <topi/nn/dense.h>
#include <topi/nn/dilate.h>
#include <topi/nn/flatten.h>
#include <topi/nn/mapping.h>
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
#include <topi/nn/local_response_norm.h>
#include <topi/nn/batch_matmul.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
/* Ops from nn.h */
TVM_REGISTER_GLOBAL("topi.nn.relu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = relu<float>(args[0]);
});
TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = leaky_relu(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.prelu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = prelu(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.pad")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = pad(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/dense.h */
TVM_REGISTER_GLOBAL("topi.nn.dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::dense(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/bias_add.h */
TVM_REGISTER_GLOBAL("topi.nn.bias_add")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::bias_add(args[0], args[1], args[2]);
});
/* Ops from nn/batch_matmul.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::batch_matmul(args[0], args[1]);
});
/* Ops from nn/dilate.h */
TVM_REGISTER_GLOBAL("topi.nn.dilate")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::dilate(args[0], args[1]);
});
/* Ops from nn/flatten.h */
TVM_REGISTER_GLOBAL("topi.nn.flatten")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::flatten(args[0]);
});
/* Ops from nn/mapping.h */
TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::scale_shift_nchw(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::scale_shift_nhwc(args[0], args[1], args[2]);
});
/* Ops from nn/pooling.h */
TVM_REGISTER_GLOBAL("topi.nn.pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4],
static_cast<nn::PoolType>(static_cast<int>(args[5])),
args[6], args[7], args[8]);
});
TVM_REGISTER_GLOBAL("topi.nn.global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::global_pool(args[0],
static_cast<nn::PoolType>(static_cast<int>(args[1])), args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::adaptive_pool(args[0], args[1],
static_cast<nn::PoolType>(static_cast<int>(args[2])),
args[3]);
});
TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::adaptive_pool3d(args[0], args[1],
static_cast<nn::PoolType>(static_cast<int>(args[2])),
args[3]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool1d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool1d(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool3d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool3d(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
/* Ops from nn/softmax.h */
TVM_REGISTER_GLOBAL("topi.nn.softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::log_softmax(args[0]);
});
TVM_REGISTER_GLOBAL("topi.nn.lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::lrn(args[0], args[1], args[2],
static_cast<double>(args[3]),
static_cast<double>(args[4]),
static_cast<double>(args[5]));
});
/* Ops from nn/bnn.h */
TVM_REGISTER_GLOBAL("topi.nn.binarize_pack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::binarize_pack(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.binary_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::binary_dense(args[0], args[1]);
});
} // namespace topi
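The pooling registrations above receive the pool type as an integer and cast it to `nn::PoolType`. A minimal sketch of a call through the runtime registry (illustrative only, not part of this patch; assumes `data` is a 4-D NCHW `tvm::te::Tensor` and that integer 1 maps to max pooling as defined in topi/nn/pooling.h):

    const tvm::runtime::PackedFunc* gp = tvm::runtime::Registry::Get("topi.nn.global_pool");
    tvm::te::Tensor out = (*gp)(data, /*pool_type=*/1, /*layout=*/"NCHW");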
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of reduction operators
* \file reduction.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <topi/reduction.h>
#include <topi/util.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
TVM_REGISTER_GLOBAL("topi.sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.min")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.max")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.argmin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.argmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.prod")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.all")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.any")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
});
} // namespace topi
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of TVM schedules
* \file schedule.cc
*/
#define TOPI_REDUCE_ATLEAST1D 0
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/target/generic_func.h>
#include <topi/generic/default.h>
#include <topi/generic/extern.h>
#include <topi/generic/injective.h>
#include <topi/cuda/dense.h>
#include <topi/cuda/injective.h>
#include <topi/cuda/pooling.h>
#include <topi/cuda/reduction.h>
#include <topi/cuda/softmax.h>
#include <topi/cuda/normalization.h>
#include <topi/x86/bnn.h>
#include <topi/x86/default.h>
#include <topi/x86/injective.h>
#include <topi/rocm/dense.h>
#include <topi/rocm/injective.h>
#include <topi/rocm/pooling.h>
#include <topi/rocm/reduction.h>
#include <topi/rocm/softmax.h>
#include <topi/rocm/normalization.h>
#include <topi/detail/tensor_utils.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
TVM_REGISTER_GLOBAL("topi.TEST_create_target")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tvm::Target::Create(args[0]);
});
/* Generic schedules */
TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[2]) {
*rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::generic::default_schedule(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_extern")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::schedule_extern(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
});
/* x86 schedules */
TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_binarize_pack(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_binary_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.x86.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[2]) {
*rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::x86::default_schedule(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
});
/* ROCm schedules */
TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_global_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_reduce(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_lrn(args[0]);
});
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_global_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_reduce(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_lrn(args[0]);
});
/* Utility functions */
TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::detail::is_empty_shape(args[0]);
});
TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
});
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
tvm::te::Schedule(const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::Current(false);
Array<Tensor> outs;
ObjectRef argNodeRef = args[0];
if (argNodeRef->type_index() == outs->type_index()) {
outs = args[0];
} else {
outs = Array<Tensor> { args[0] };
}
*ret = builder(target, outs);
});
}
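// Illustrative only (not part of this patch): once registered, a schedule
// generic function is dispatched according to the keys of the current target.
// A minimal sketch, assuming `outs` is a tvm::Array<tvm::te::Tensor> holding
// the output ops to schedule:
//
//   tvm::With<tvm::Target> scope(tvm::Target::Create("cuda"));
//   tvm::te::Schedule s = tvm::GenericFunc::Get("schedule_injective")(outs);
//
// WrapSchedule above unpacks the packed-call arguments and forwards
// (target, outs) to the builder registered for "cuda"/"gpu", i.e.
// topi::cuda::schedule_injective in the registration below.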
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
.set_default(WrapSchedule(topi::generic::schedule_injective))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective));
TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax));
TVM_REGISTER_GENERIC_FUNC(schedule_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense))
.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense));
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
.set_default(WrapSchedule(topi::generic::default_schedule));
TVM_REGISTER_GENERIC_FUNC(schedule_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce));
TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack));
TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));
/*! \brief Builder function for instantiating schedules from existing schedules. */
using FTVMScheduleFromExistingBuilder = std::function<
tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
*ret = builder(args[0], args[1]);
});
}
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(
topi::cuda::schedule_injective_from_existing));
/*! \brief Builder function for instantiating dense ops. */
using FTVMDenseOpBuilder = std::function<tvm::te::Tensor(const Target& target,
const tvm::te::Tensor& data,
const tvm::te::Tensor& weight,
const tvm::te::Tensor& bias,
const DataType& out_dtype)>;
/*!
* \brief Helper function for registering dense ops matching the
* FTVMDenseOpBuilder signature. The op builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The op builder to wrap.
*
* \return The wrapped op builder
*/
inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::Current(false);
Tensor data = args[0];
Tensor weight = args[1];
Tensor bias = args[2];
DataType out_dtype = args[3];
*ret = builder(target, data, weight, bias, out_dtype);
});
}
TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
const tvm::te::Tensor& data,
const tvm::te::Tensor& weight,
const tvm::te::Tensor& bias,
const DataType& out_dtype) {
return topi::nn::dense(data, weight, bias, out_dtype);
}))
.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda))
.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm));
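// Illustrative only (not part of this patch): the "dense" generic function
// selects an op builder from the current target's keys and falls back to the
// default topi::nn::dense builder above. A minimal sketch, assuming
// data/weight/bias are te::Tensor and out_dtype is a DataType:
//
//   tvm::With<tvm::Target> scope(tvm::Target::Create("cuda"));
//   tvm::te::Tensor out =
//       tvm::GenericFunc::Get("dense")(data, weight, bias, out_dtype);
//   // dispatches to topi::cuda::dense_cuda under the "cuda"/"gpu" keys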
} // namespace topi
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of TVM operators and schedules
* \file topi.cc
*/
#define TOPI_REDUCE_ATLEAST1D 0
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/target/generic_func.h>
#include <topi/broadcast.h>
#include <topi/elemwise.h>
#include <topi/nn.h>
#include <topi/reduction.h>
#include <topi/transform.h>
#include <topi/nn/bias_add.h>
#include <topi/nn/bnn.h>
#include <topi/nn/dense.h>
#include <topi/nn/dilate.h>
#include <topi/nn/flatten.h>
#include <topi/nn/mapping.h>
#include <topi/nn/pooling.h>
#include <topi/nn/softmax.h>
#include <topi/nn/local_response_norm.h>
#include <topi/nn/batch_matmul.h>
#include <topi/vision/reorg.h>
#include <topi/generic/default.h>
#include <topi/generic/extern.h>
#include <topi/generic/injective.h>
#include <topi/cuda/dense.h>
#include <topi/cuda/injective.h>
#include <topi/cuda/pooling.h>
#include <topi/cuda/reduction.h>
#include <topi/cuda/softmax.h>
#include <topi/cuda/normalization.h>
#include <topi/x86/bnn.h>
#include <topi/x86/default.h>
#include <topi/x86/injective.h>
#include <topi/rocm/dense.h>
#include <topi/rocm/injective.h>
#include <topi/rocm/pooling.h>
#include <topi/rocm/reduction.h>
#include <topi/rocm/softmax.h>
#include <topi/rocm/normalization.h>
#include <topi/detail/tensor_utils.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
/*! \brief Canonicalize an argument that may be Array<Expr> or int to Array<Expr> */
Array<Integer> ArrayOrInt(TVMArgValue arg) {
if (arg.type_code() == kDLInt || arg.type_code() == kDLUInt) {
Array<Integer> result;
result.push_back(arg.operator int());
return result;
} else {
return arg;
}
}
inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kTVMObjectHandle &&
static_cast<Object*>(
arg.value().v_handle)->IsInstance<tvm::te::TensorNode>());
}
TVM_REGISTER_GLOBAL("topi.TEST_create_target")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tvm::Target::Create(args[0]);
});
/* Ops from broadcast.h */
#define TOPI_REGISTER_BCAST_OP(OpName, Op) \
TVM_REGISTER_GLOBAL(OpName) \
.set_body([](TVMArgs args, TVMRetValue *rv) { \
bool lhs_is_tensor = IsTensorType(args[0]); \
bool rhs_is_tensor = IsTensorType(args[1]); \
if (lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::te::Tensor()); \
} else if (!lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
args[1].operator tvm::te::Tensor()); \
} else if (lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::PrimExpr()); \
} else if (!lhs_is_tensor && !rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::PrimExpr(), \
args[1].operator tvm::PrimExpr()); \
} \
});
TOPI_REGISTER_BCAST_OP("topi.add", topi::add);
TOPI_REGISTER_BCAST_OP("topi.subtract", topi::subtract);
TOPI_REGISTER_BCAST_OP("topi.multiply", topi::multiply);
TOPI_REGISTER_BCAST_OP("topi.divide", topi::divide);
TOPI_REGISTER_BCAST_OP("topi.floor_divide", topi::floor_divide);
TOPI_REGISTER_BCAST_OP("topi.mod", topi::mod);
TOPI_REGISTER_BCAST_OP("topi.floor_mod", topi::floor_mod);
TOPI_REGISTER_BCAST_OP("topi.maximum", topi::maximum);
TOPI_REGISTER_BCAST_OP("topi.minimum", topi::minimum);
TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal);
TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
TVM_REGISTER_GLOBAL("topi.broadcast_to")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = broadcast_to(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.logical_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = logical_not(args[0]);
});
TVM_REGISTER_GLOBAL("topi.bitwise_not")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = bitwise_not(args[0]);
});
/* Ops from elemwise.h */
TVM_REGISTER_GLOBAL("topi.exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = exp(args[0]);
});
TVM_REGISTER_GLOBAL("topi.fast_exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_exp(args[0]);
});
TVM_REGISTER_GLOBAL("topi.erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = erf(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tan(args[0]);
});
TVM_REGISTER_GLOBAL("topi.cos")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cos(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sin(args[0]);
});
TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.fast_tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sigmoid")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sigmoid(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sqrt(args[0]);
});
TVM_REGISTER_GLOBAL("topi.rsqrt")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rsqrt(args[0]);
});
TVM_REGISTER_GLOBAL("topi.log")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = log(args[0]);
});
TVM_REGISTER_GLOBAL("topi.identity")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = identity(args[0]);
});
TVM_REGISTER_GLOBAL("topi.negative")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = negative(args[0]);
});
TVM_REGISTER_GLOBAL("topi.clip")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = clip(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.cast")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cast(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.reinterpret")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = reinterpret(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.elemwise_sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = elemwise_sum(args[0]);
});
TVM_REGISTER_GLOBAL("topi.sign")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = sign(args[0]);
});
TVM_REGISTER_GLOBAL("topi.full")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.full_like")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = full_like(args[0], args[1]);
});
/* Ops from nn.h */
TVM_REGISTER_GLOBAL("topi.nn.relu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = relu<float>(args[0]);
});
TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = leaky_relu(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.prelu")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = prelu(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.pad")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = pad(args[0], args[1], args[2], args[3]);
});
/* Ops from reduction.h */
TVM_REGISTER_GLOBAL("topi.sum")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::sum(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.min")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::min(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.max")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::max(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.argmin")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::argmin(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.argmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.prod")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.all")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
});
TVM_REGISTER_GLOBAL("topi.any")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
});
/* Ops from transform.h */
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = expand_dims(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.transpose")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = transpose(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.flip")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = flip(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.reshape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = reshape(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.squeeze")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = squeeze(args[0], ArrayOrInt(args[1]));
});
TVM_REGISTER_GLOBAL("topi.concatenate")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = concatenate(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.stack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = stack(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = shape(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.ndarray_size")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = ndarray_size(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.split")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
*rv = split_sections(args[0], args[1], args[2]);
} else {
*rv = split(args[0], args[1], args[2]);
}
});
TVM_REGISTER_GLOBAL("topi.layout_transform")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = layout_transform(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 3) {
std::string mode = args[2];
*rv = take(args[0], args[1], mode);
} else {
int axis = args[2];
std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
}
});
TVM_REGISTER_GLOBAL("topi.sequence_mask")
.set_body([](TVMArgs args, TVMRetValue *rv) {
double pad_val = args[2];
int axis = args[3];
*rv = sequence_mask(args[0], args[1], pad_val, axis);
});
TVM_REGISTER_GLOBAL("topi.where")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = where(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.arange")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = arange(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.repeat")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = repeat(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.tile")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.gather_nd")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = gather_nd(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.unravel_index")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = unravel_index(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
switch ( args.size() ) {
case 2: *rv = matmul(args[0], args[1]); break;
case 3: *rv = matmul(args[0], args[1], args[2]); break;
case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break;
default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
}});
TVM_REGISTER_GLOBAL("topi.tensordot")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
*rv = tensordot(args[0], args[1]);
} else if (args.size() == 3) {
*rv = tensordot(args[0], args[1], args[2]);
} else {
Array<PrimExpr> axes = args[3];
*rv = tensordot(args[0], args[1], args[2], axes);
}
});
TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.one_hot")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int depth = args[3];
int axis = args[4];
DataType dtype = args[5];
*rv = one_hot(args[0], args[1], args[2], depth, axis, dtype);
});
/* Ops from nn/bnn.h */
TVM_REGISTER_GLOBAL("topi.nn.binarize_pack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::binarize_pack(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.binary_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::binary_dense(args[0], args[1]);
});
/* Ops from nn/dense.h */
TVM_REGISTER_GLOBAL("topi.nn.dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::dense(args[0], args[1], args[2], args[3]);
});
/* Ops from nn/bias_add.h */
TVM_REGISTER_GLOBAL("topi.nn.bias_add")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::bias_add(args[0], args[1], args[2]);
});
/* Ops from nn/batch_matmul.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::batch_matmul(args[0], args[1]);
});
/* Ops from nn/dilate.h */
TVM_REGISTER_GLOBAL("topi.nn.dilate")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::dilate(args[0], args[1]);
});
/* Ops from nn/flatten.h */
TVM_REGISTER_GLOBAL("topi.nn.flatten")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::flatten(args[0]);
});
/* Ops from nn/mapping.h */
TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nchw")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::scale_shift_nchw(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.scale_shift_nhwc")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::scale_shift_nhwc(args[0], args[1], args[2]);
});
/* Ops from nn/pooling.h */
TVM_REGISTER_GLOBAL("topi.nn.pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool_grad")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool_grad(args[0], args[1], args[2], args[3], args[4],
static_cast<nn::PoolType>(static_cast<int>(args[5])),
args[6], args[7], args[8]);
});
TVM_REGISTER_GLOBAL("topi.nn.global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::global_pool(args[0],
static_cast<nn::PoolType>(static_cast<int>(args[1])), args[2]);
});
TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::adaptive_pool(args[0], args[1],
static_cast<nn::PoolType>(static_cast<int>(args[2])),
args[3]);
});
TVM_REGISTER_GLOBAL("topi.nn.adaptive_pool3d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::adaptive_pool3d(args[0], args[1],
static_cast<nn::PoolType>(static_cast<int>(args[2])),
args[3]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool1d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool1d(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
TVM_REGISTER_GLOBAL("topi.nn.pool3d")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::pool3d(args[0], args[1], args[2], args[3],
static_cast<nn::PoolType>(static_cast<int>(args[4])),
args[5], args[6], args[7]);
});
/* Ops from nn/softmax.h */
TVM_REGISTER_GLOBAL("topi.nn.softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.nn.log_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::log_softmax(args[0]);
});
TVM_REGISTER_GLOBAL("topi.nn.lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::lrn(args[0], args[1], args[2],
static_cast<double>(args[3]),
static_cast<double>(args[4]),
static_cast<double>(args[5]));
});
TVM_REGISTER_GLOBAL("topi.vision.reorg")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::reorg(args[0], args[1]);
});
/* Generic schedules */
TVM_REGISTER_GLOBAL("topi.generic.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[2]) {
*rv = topi::generic::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::generic::default_schedule(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_extern")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::schedule_extern(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.generic.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::generic::schedule_injective_from_existing(args[0], args[1]);
});
/* x86 schedules */
TVM_REGISTER_GLOBAL("topi.x86.schedule_binarize_pack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_binarize_pack(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_binary_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_binary_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.x86.default_schedule")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[2]) {
*rv = topi::x86::default_schedule_auto_inline(args[0], args[1]);
} else {
*rv = topi::x86::default_schedule(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.x86.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::x86::schedule_injective_from_existing(args[0], args[1]);
});
/* ROCm schedules */
TVM_REGISTER_GLOBAL("topi.rocm.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = rocm::dense_rocm(args[0], args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_global_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_reduce(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_lrn(args[0]);
});
/* CUDA schedules */
TVM_REGISTER_GLOBAL("topi.cuda.dense_cuda")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = cuda::dense_cuda(args[0], args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_dense")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_injective_from_existing(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_global_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_reduce")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_reduce(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.cuda.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::cuda::schedule_lrn(args[0]);
});
/* Utility functions */
TVM_REGISTER_GLOBAL("topi.util.is_empty_shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::detail::is_empty_shape(args[0]);
});
TVM_REGISTER_GLOBAL("topi.util.bilinear_sample_nchw")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = detail::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
});
/*! \brief Builder function for instantiating schedules. */
using FTVMScheduleBuilder = std::function<
tvm::te::Schedule(const tvm::Target& target, const tvm::Array<tvm::te::Tensor>& outs)>;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline PackedFunc WrapSchedule(FTVMScheduleBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::Current(false);
Array<Tensor> outs;
ObjectRef argNodeRef = args[0];
if (argNodeRef->type_index() == outs->type_index()) {
outs = args[0];
} else {
outs = Array<Tensor> { args[0] };
}
*ret = builder(target, outs);
});
}
TVM_REGISTER_GENERIC_FUNC(schedule_injective)
.set_default(WrapSchedule(topi::generic::schedule_injective))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_injective))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_injective));
TVM_REGISTER_GENERIC_FUNC(schedule_softmax)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_softmax));
TVM_REGISTER_GENERIC_FUNC(schedule_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense))
.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense));
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
.set_default(WrapSchedule(topi::generic::default_schedule));
TVM_REGISTER_GENERIC_FUNC(schedule_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_global_pool)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_global_pool));
TVM_REGISTER_GENERIC_FUNC(schedule_reduce)
.set_default(WrapSchedule(topi::generic::default_schedule_auto_inline))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule_auto_inline))
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_reduce));
TVM_REGISTER_GENERIC_FUNC(schedule_binarize_pack)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binarize_pack));
TVM_REGISTER_GENERIC_FUNC(schedule_binary_dense)
.set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::schedule_binary_dense));
/*! \brief Builder function for instantiating schedules from existing schedules. */
using FTVMScheduleFromExistingBuilder = std::function<
tvm::te::Schedule(tvm::te::Schedule sch, const tvm::te::Tensor& out)>;
/*!
* \brief Helper function for registering generic functions matching the
* FTVMScheduleFromExistingBuilder signature. The schedule builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The schedule builder to wrap.
*
* \return The wrapped schedule builder
*/
inline PackedFunc WrapScheduleFromExisting(FTVMScheduleFromExistingBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
*ret = builder(args[0], args[1]);
});
}
TVM_REGISTER_GENERIC_FUNC(schedule_injective_from_existing)
.set_default(WrapScheduleFromExisting(topi::generic::schedule_injective_from_existing))
.register_func({ "cpu" }, WrapScheduleFromExisting(topi::x86::schedule_injective_from_existing))
.register_func({ "cuda", "gpu" }, WrapScheduleFromExisting(
topi::cuda::schedule_injective_from_existing));
/*! \brief Builder function for instantiating dense ops. */
using FTVMDenseOpBuilder = std::function<tvm::te::Tensor(const Target& target,
const tvm::te::Tensor& data,
const tvm::te::Tensor& weight,
const tvm::te::Tensor& bias,
const DataType& out_dtype)>;
/*!
* \brief Helper function for registering dense ops matching the
* FTVMDenseOpBuilder signature. The op builder function is wrapped
* with a PackedFunc suitable for passing to a tvm::GenericFunc.
*
* \param builder The op builder to wrap.
*
* \return The wrapped op builder
*/
inline PackedFunc WrapDenseOp(FTVMDenseOpBuilder builder) {
return PackedFunc([builder](TVMArgs args, TVMRetValue* ret) {
auto target = Target::Current(false);
Tensor data = args[0];
Tensor weight = args[1];
Tensor bias = args[2];
DataType out_dtype = args[3];
*ret = builder(target, data, weight, bias, out_dtype);
});
}
TVM_REGISTER_GENERIC_FUNC(dense)
.set_default(WrapDenseOp([](const Target& target,
const tvm::te::Tensor& data,
const tvm::te::Tensor& weight,
const tvm::te::Tensor& bias,
const DataType& out_dtype) {
return topi::nn::dense(data, weight, bias, out_dtype);
}))
.register_func({ "cuda", "gpu" }, WrapDenseOp(topi::cuda::dense_cuda))
.register_func({ "rocm" }, WrapDenseOp(topi::rocm::dense_rocm));
} // namespace topi
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of transform operators
* \file transform.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <topi/transform.h>
#include <topi/util.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
TVM_REGISTER_GLOBAL("topi.expand_dims")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = expand_dims(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.transpose")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = transpose(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.flip")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = flip(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.reshape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = reshape(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.squeeze")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = squeeze(args[0], ArrayOrInt(args[1]));
});
TVM_REGISTER_GLOBAL("topi.concatenate")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = concatenate(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.stack")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = stack(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.shape")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = shape(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.ndarray_size")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = ndarray_size(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.split")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) {
*rv = split_sections(args[0], args[1], args[2]);
} else {
*rv = split(args[0], args[1], args[2]);
}
});
TVM_REGISTER_GLOBAL("topi.layout_transform")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = layout_transform(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.take")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 3) {
std::string mode = args[2];
*rv = take(args[0], args[1], mode);
} else {
int axis = args[2];
std::string mode = args[3];
*rv = take(args[0], args[1], axis, mode);
}
});
TVM_REGISTER_GLOBAL("topi.sequence_mask")
.set_body([](TVMArgs args, TVMRetValue *rv) {
double pad_val = args[2];
int axis = args[3];
*rv = sequence_mask(args[0], args[1], pad_val, axis);
});
TVM_REGISTER_GLOBAL("topi.where")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = where(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.arange")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = arange(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.repeat")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = repeat(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("topi.tile")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.gather_nd")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = gather_nd(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.unravel_index")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = unravel_index(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
switch ( args.size() ) {
case 2: *rv = matmul(args[0], args[1]); break;
case 3: *rv = matmul(args[0], args[1], args[2]); break;
case 4: *rv = matmul(args[0], args[1], args[2], args[3]); break;
default: CHECK(0) << "topi.matmul expects 2, 3 or 4 arguments";
}});
TVM_REGISTER_GLOBAL("topi.tensordot")
.set_body([](TVMArgs args, TVMRetValue *rv) {
if (args.size() == 2) {
*rv = tensordot(args[0], args[1]);
} else if (args.size() == 3) {
*rv = tensordot(args[0], args[1], args[2]);
} else {
Array<PrimExpr> axes = args[3];
*rv = tensordot(args[0], args[1], args[2], axes);
}
});
TVM_REGISTER_GLOBAL("topi.strided_slice")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.one_hot")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int depth = args[3];
int axis = args[4];
DataType dtype = args[5];
*rv = one_hot(args[0], args[1], args[2], depth, axis, dtype);
});
} // namespace topi
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \brief Registration of vision operators
* \file vision.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <topi/vision/reorg.h>
namespace topi {
using namespace tvm;
using namespace tvm::runtime;
TVM_REGISTER_GLOBAL("topi.vision.reorg")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = vision::reorg(args[0], args[1]);
});
} // namespace topi