Commit fd6560e1 by Thomas Viehmann Committed by masahi

add rocm schedules to topi C++ (#4507)

This imports the CUDA schedules to rocm.
parent 40f1886c
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rocm/injective.h
* \brief rocm schedule for injective operations
*/
#ifndef TOPI_ROCM_INJECTIVE_H_
#define TOPI_ROCM_INJECTIVE_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/injective.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Updates an existing schedule for the given injective ops.
*
* \param sch The schedule to update.
* \param out The tensor representing the injective op.
*
* \return The updated schedule.
*/
inline Schedule schedule_injective_from_existing(Schedule sch, const Tensor& out) {
return topi::cuda::schedule_injective_from_existing(sch, out);
}
/*!
* \brief Create a rocm schedule for the given output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_injective(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_INJECTIVE_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rocm/pooling.h
* \brief rocm schedule for pooling operations
*/
#ifndef TOPI_ROCM_POOLING_H_
#define TOPI_ROCM_POOLING_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "topi/detail/array_utils.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/pooling.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Create a rocm schedule for pool
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_pool(target, outs);
}
/*!
* \brief Create a rocm schedule for global_pool
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_global_pool(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_global_pool(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_POOLING_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rocm/reduction.h
* \brief rocm schedule for reduction operations
*/
#ifndef TOPI_ROCM_REDUCTION_H_
#define TOPI_ROCM_REDUCTION_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/reduction.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Create a rocm schedule for a reduce operation.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
return topi::cuda::schedule_reduce(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_REDUCTION_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file rocm/injective.h
* \brief ROCM schedule for injective operations
*/
#ifndef TOPI_ROCM_SOFTMAX_H_
#define TOPI_ROCM_SOFTMAX_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/operation.h"
#include "tvm/build_module.h"
#include "topi/cuda/softmax.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Create a rocm schedule for the given softmax output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs) {
return topi::cuda::schedule_softmax(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_SOFTMAX_H_
......@@ -66,6 +66,10 @@
#include <topi/x86/injective.h>
#include <topi/rocm/dense.h>
#include <topi/rocm/injective.h>
#include <topi/rocm/pooling.h>
#include <topi/rocm/reduction.h>
#include <topi/rocm/softmax.h>
#include <topi/rocm/normalization.h>
namespace topi {
......@@ -638,6 +642,36 @@ TVM_REGISTER_GLOBAL("topi.rocm.schedule_dense")
*rv = topi::rocm::schedule_dense(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_injective_from_existing")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_injective_from_existing(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_global_pool")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_global_pool(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_reduce")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_reduce(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_softmax")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_softmax(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("topi.rocm.schedule_lrn")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = topi::rocm::schedule_lrn(args[0], args[1]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment