Commit f280f23a by alex-weaver, committed by Tianqi Chen

Porting schedules (except convolutions) to C++ (#763)

* Ported injective schedules to C++. Added some elementwise ops.

* Fix lint errors

* Added reduction ops and schedules

* Fix lint errors

* Fix lint errors

* Fix lint errors

* Added transform ops

* Fix lint errors

* Fix lint errors

* Added softmax, log_softmax, leaky_relu and flatten ops.
Fixed issue where TVM_DECLARE_INTRIN_UNARY used the PureExtern flag
instead of PureIntrinsic.
Added softmax CUDA schedule.

* Fix lint

* Fix lint

* Added binary_dense, batch_norm_inference, dense, dilate, scale_shift_*,
global_pool and pool ops.
Extended pad to allow specifying pad_value.
Fixed issue where pad would throw if padding was zero in all dimensions.

* Fix lint

* Fix lint

* Added CUDA schedules for dense, pool and global_pool

* Added extern schedules for generic and CUDA

* Fix lint

* Added x86 binary schedules

* Fix lint

* Added rocm dense schedule. Added rocBLAS and cuBLAS support to dense ops

* Added pow ops. Added x86 default and injective schedules

* Fix lint

* Fix lint

* Fix lint

* Fix lint

* Fix lint

* Fix indent

* Removed schedules directory

* Changed left_shift, right_shift to operators. Changed pad_value in pad() to remove pointer usage

* Fixed usage of pad in nn/pooling.h. Fixed declaration of operator>>

* Fixed comments for shift operators

* Added comments to utility functions

* Added TOPI C++ library, exporting broadcast_add op

* Fix lint

* Share libinfo.py with TVM

* Fix lint

* Add other broadcast ops

* Fix lint

* Fix imports in topi

* Fix lib names

* Fixed build issue where windows builds don't apply correct definitions

* Removed TVM_EXPORTS from topi library

* Attempted CI build fix

* Add topi lib to tvm_multilib

* Fix Jenkinsfile

* Added TOPI build target to Makefile

* Fix nn op namespaces.

* Fix lint

* Renamed TOPI lib to libtvm_topi

* Removed _ffi/base.py

* Remove _ffi from topi, now shared with tvm.

* Make libtvm_topi loading optional

* Fix compiler warnings

* Fix lint

* Fix lint

* Fix lint

* Fix build error by making new libs argument to Target optional

* Added C++ Target type interop. Added registration of remaining C++ ops and schedules. Added test of broadcast ops

* Fix lint

* Fix lint

* Fix compile error

* Fix compiler warnings

* Fix compiler warnings

* Fixed int vector interop. Fixed argmin incorrectly invoking argmax. Fixed corner case in default schedules of attempting to fuse 0 length axes. Added tests for reduce ops.

* Refactored reduce builders

* Fixed typos in topi.cc. Added basic test.

* Fixed padding size error. Added dense, dilate, pooling tests

* Fixed issue where clip would output a different dtype to the input. Added split_sections op to cover the other mode of the python split op. Added tests.

* Changed extension type numbers to avoid clash with NNVM

* Fix lint

* Fix compiler warnings

* Removed use of std::vector from the public TOPI API

* Fix lint

* Add TOPI C++ tests to CI

* Fixed detail namespacing. Improved comments.
parent 944de73b
......@@ -44,9 +44,7 @@ if(MSVC)
add_definitions(-DWIN32_LEAN_AND_MEAN)
add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-DTVM_EXPORTS)
add_definitions(-DHalide_SHARED)
add_definitions(-DHalide_EXPORTS)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
......@@ -82,6 +80,10 @@ file(GLOB COMPILER_SRCS
src/op/*.cc
src/schedule/*.cc
)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/topi/include)
file(GLOB TOPI_SRCS
topi/src/*.cc
)
file(GLOB_RECURSE HALIDEIR_SRCS HalideIR/src/*.cpp)
list(APPEND COMPILER_SRCS ${HALIDEIR_SRCS})
file(GLOB RUNTIME_SRCS src/runtime/*.cc)
......@@ -209,8 +211,10 @@ endif()
list(APPEND RUNTIME_SRCS ${GROUP_Include})
add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
target_link_libraries(tvm_topi tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
target_link_libraries(tvm_runtime ${TVM_RUNTIME_LINKER_LIBS})
install(TARGETS tvm_runtime DESTINATION lib${LIB_SUFFIX})
if (INSTALL_DEV)
......@@ -242,3 +246,10 @@ else(INSTALL_DEV)
PATTERN "*.h"
)
endif(INSTALL_DEV)
if(MSVC)
target_compile_definitions(tvm PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DHalide_EXPORTS)
target_compile_definitions(tvm PRIVATE -DTVM_EXPORTS)
target_compile_definitions(tvm_runtime PRIVATE -DTVM_EXPORTS)
endif()
\ No newline at end of file
......@@ -4,10 +4,11 @@
// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/
// tvm libraries
tvm_runtime = "lib/libtvm_runtime.so, config.mk"
tvm_lib = "lib/libtvm.so, " + tvm_runtime
topi_lib = "lib/libtvm_topi.so"
tvm_runtime = "lib/libtvm_runtime.so, config.mk, "
tvm_lib = "lib/libtvm.so, " + tvm_runtime + topi_lib
// LLVM upstream lib
tvm_multilib = "lib/libtvm_llvm40.so, lib/libtvm_llvm50.so, lib/libtvm_llvm60.so, " + tvm_runtime
tvm_multilib = "lib/libtvm_llvm40.so, lib/libtvm_llvm50.so, lib/libtvm_llvm60.so, " + tvm_runtime + topi_lib
// command to start a docker container
docker_run = 'tests/ci_build/ci_build.sh'
......
......@@ -58,6 +58,7 @@ OPENGL_SRC = $(wildcard src/runtime/opengl/*.cc)
RPC_SRC = $(wildcard src/runtime/rpc/*.cc)
GRAPH_SRC = $(wildcard src/runtime/graph/*.cc)
RUNTIME_SRC = $(wildcard src/runtime/*.cc)
TOPI_SRC = $(wildcard topi/src/*.cc)
# Objectives
LLVM_BUILD = build/llvm${LLVM_VERSION}
......@@ -71,11 +72,13 @@ RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC))
GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC))
CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ)
RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
TOPI_OBJ = $(patsubst topi/%.cc, build/%.o, $(TOPI_SRC))
CONTRIB_OBJ =
# Deps
ALL_DEP = $(CC_OBJ) $(CONTRIB_OBJ) $(LIB_HALIDEIR)
RUNTIME_DEP = $(RUNTIME_OBJ)
TOPI_DEP = $(TOPI_OBJ)
# Dependency specific rules
ifdef CUDA_PATH
......@@ -198,10 +201,11 @@ else
JVM_PKG_PROFILE := $(JVM_PKG_PROFILE)-cpu
endif
BUILD_TARGETS ?= lib/libtvm.$(SHARED_LIBRARY_SUFFIX) lib/libtvm_runtime.$(SHARED_LIBRARY_SUFFIX)
BUILD_TARGETS ?= lib/libtvm.$(SHARED_LIBRARY_SUFFIX) lib/libtvm_runtime.$(SHARED_LIBRARY_SUFFIX) lib/libtvm_topi.$(SHARED_LIBRARY_SUFFIX)
all: ${BUILD_TARGETS}
runtime: lib/libtvm_runtime.$(SHARED_LIBRARY_SUFFIX)
web: lib/libtvm_web_runtime.js lib/libtvm_web_runtime.bc
topi: lib/libtvm_topi.$(SHARED_LIBRARY_SUFFIX)
include tests/cpp/unittest.mk
......@@ -226,10 +230,19 @@ build/%.o: src/%.cc
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@
build/src/%.o: topi/src/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/src/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@
lib/libtvm.dylib: $(ALL_DEP) $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm_topi.dylib: lib/libtvm.so $(TOPI_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -L./lib -ltvm -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm_runtime.dylib: $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
......@@ -238,6 +251,10 @@ lib/libtvm.so: $(ALL_DEP) $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm_topi.so: lib/libtvm.so $(TOPI_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -L./lib -ltvm -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
......
......@@ -31,19 +31,23 @@ struct Target {
std::unordered_set<std::string> keys;
/*! \brief Options for this target */
std::vector<std::string> options;
/*! \brief Set of imported libs */
std::unordered_set<std::string> libs;
Target(const std::string& target_name,
DLDeviceType device_type,
int max_num_threads,
int thread_warp_size,
const std::unordered_set<std::string>& keys,
const std::vector<std::string>& options) :
const std::vector<std::string>& options,
const std::unordered_set<std::string>& libs = {}) :
target_name(target_name),
device_type(device_type),
max_num_threads(max_num_threads),
thread_warp_size(thread_warp_size),
keys(keys),
options(options) {
options(options),
libs(libs) {
}
/*! \return the full device string to pass to codegen::Build */
......
......@@ -73,7 +73,7 @@ inline int GetVectorBytes(Type dtype) {
/*! \brief a named variable in TVM */
class Var : public HalideIR::VarExpr {
public:
explicit Var(const std::string& name_hint = "v",
EXPORT explicit Var(const std::string& name_hint = "v",
Type t = Int(32)) : VarExpr(name_hint, t) {}
explicit Var(std::shared_ptr<Node> n) : VarExpr(n) {}
explicit Var(VarExpr v) : VarExpr(v) {}
......
......@@ -45,13 +45,19 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
// Unary intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureExtern); \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \
TVM_DECLARE_INTRIN_UNARY(exp);
TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
inline Expr pow(Expr x, Expr y) {
return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic);
}
} // namespace tvm
#endif // TVM_IR_OPERATOR_H_
......@@ -28,7 +28,7 @@ namespace ir {
* \param vrange The range information about the variable.
* \return Canonicalized statement.
*/
Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());
EXPORT Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());
/*!
* \brief Simplify the statement.
......@@ -62,7 +62,7 @@ Expr CanonicalSimplify(Expr expr,
* \param rhs The right operand
* \return The comparison result.
*/
bool Equal(const Expr& lhs, const Expr& rhs);
EXPORT bool Equal(const Expr& lhs, const Expr& rhs);
/*!
* \brief Deep compare lhs and rhs
......
......@@ -353,7 +353,7 @@ class ExternOpNode : public OperationNode {
v->Visit("inputs", &inputs);
v->Visit("body", &body);
}
static Operation make(std::string name,
EXPORT static Operation make(std::string name,
std::string tag,
Array<Tensor> inputs,
Array<Buffer> input_placeholders,
......
......@@ -56,24 +56,24 @@ class Stage : public NodeRef {
* \brief set the memory scope of the stage
* \param scope The memory scope.
*/
Stage& set_scope(std::string scope); // NOLINT(*)
EXPORT Stage& set_scope(std::string scope); // NOLINT(*)
/*!
* \brief specify the schedule to be computed at the parent schedule's scope.
* \param parent The parent schedule.
* \param scope The iteration point to carry the schedule.
* \return reference to self.
*/
Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
EXPORT Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
/*!
* \brief Compute the function inline.
* \return reference to self.
*/
Stage& compute_inline(); // NOLINT(*)
EXPORT Stage& compute_inline(); // NOLINT(*)
/*!
* \brief Compute the function at group root.
* \return reference to self.
*/
Stage& compute_root(); // NOLINT(*)
EXPORT Stage& compute_root(); // NOLINT(*)
/*!
* \brief Bind the ivar to thread index.
*
......@@ -92,7 +92,7 @@ class Stage : public NodeRef {
* \param predicate The condition to be checked.
* \return reference to self.
*/
Stage& set_store_predicate(Expr predicate);
EXPORT Stage& set_store_predicate(Expr predicate);
/*!
* \brief Specify environment threads that launched around the group's scope.
* This can only be used in group stage.
......@@ -101,7 +101,7 @@ class Stage : public NodeRef {
* This is a beta feature.
* \return reference to self.
*/
Stage& env_threads(Array<IterVar> threads);
EXPORT Stage& env_threads(Array<IterVar> threads);
/*!
* \brief Split the parent by factor, generate
* \param parent The parent iteration domain.
......@@ -120,7 +120,7 @@ class Stage : public NodeRef {
* \param p_inner The result inner domain.
* \return reference to self.
*/
Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
EXPORT Stage& split_by_nparts(IterVar parent, Expr nparts, IterVar* p_outer, IterVar* p_inner); // NOLINT(*)
/*!
* \brief Fuse the inner outer domain to the target
* \param outer The outer domain to be fused.
......@@ -128,13 +128,13 @@ class Stage : public NodeRef {
* \param p_target The result target domain.
* \return reference to self.
*/
Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
EXPORT Stage& fuse(IterVar outer, IterVar inner, IterVar* p_target); // NOLINT(*)
/*!
* \brief Reorder the iteration
* \param order The order of iteration variable.
* \return reference to self.
*/
Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
EXPORT Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
/*!
* \brief Perform tiling on two dimensions
* The final loop order from outmost to inner most are
......@@ -150,7 +150,7 @@ class Stage : public NodeRef {
* \param p_y_inner Inner axis of y dimension
* \return reference to self.
*/
Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
EXPORT Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
Expr x_factor, Expr y_factor,
IterVar* p_x_outer, IterVar* p_y_outer,
IterVar* p_x_inner, IterVar* p_y_inner);
......@@ -159,7 +159,7 @@ class Stage : public NodeRef {
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage& vectorize(IterVar var); // NOLINT(*)
EXPORT Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Replace computation of the current stage by tensor intrinsic f.
* \param var The axis marks beginning of tensorization.
......@@ -167,19 +167,19 @@ class Stage : public NodeRef {
* \param f The Tensor compute intrinsics.
* \return reference to self.
*/
Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
EXPORT Stage& tensorize(IterVar var, TensorIntrin f); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be unrolled.
* \return reference to self.
*/
Stage& unroll(IterVar var); // NOLINT(*)
EXPORT Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
Stage& parallel(IterVar var); // NOLINT(*)
EXPORT Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief Annotate the iteration with pragma
*
......@@ -188,7 +188,7 @@ class Stage : public NodeRef {
*
* \return reference to self.
*/
Stage& pragma(IterVar var, const std::string& pragma_type); // NOLINT(*)
EXPORT Stage& pragma(IterVar var, const std::string& pragma_type); // NOLINT(*)
/*!
* \brief Fetch data in advance.
* \param domain the tensor to be prefetched
......@@ -196,7 +196,7 @@ class Stage : public NodeRef {
* \param offset the number of iterations be to fetched in advance
* \return reference to self
*/
Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
EXPORT Stage& prefetch(const Tensor &domain, IterVar var, Expr offset); //NOLINT(*)
/*!
* \brief Set alignment requirement for specific dimension.
*
......@@ -207,12 +207,12 @@ class Stage : public NodeRef {
* \param offset The required offset factor.
* \return reference to self
*/
Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
EXPORT Stage& storage_align(IterVar axis, int factor, int offset); //NOLINT(*)
/*!
* \brief Compute current stage with double buffering.
* \return reference to self.
*/
Stage& double_buffer(); // NOLINT(*)
EXPORT Stage& double_buffer(); // NOLINT(*)
/*!
* \brief Schedule for OpenGL fragment shader.
* \return reference to self.
......@@ -271,7 +271,7 @@ class Schedule : public NodeRef {
* \param include_inputs Whether include inputs if they are reachable from outputs.
* \return The new grouped stage.
*/
Stage create_group(const Array<Tensor>& outputs,
EXPORT Stage create_group(const Array<Tensor>& outputs,
const Array<Tensor>& inputs,
bool include_inputs = false);
/*!
......@@ -283,7 +283,7 @@ class Schedule : public NodeRef {
* \param readers The readers to redirect to the tensor.
* \return The created tensor.
*/
Tensor cache_read(const Tensor& tensor,
EXPORT Tensor cache_read(const Tensor& tensor,
const std::string& scope,
const Array<Operation>& readers);
/*!
......@@ -302,7 +302,7 @@ class Schedule : public NodeRef {
* \param scope The scope of the storage.
* \return The created tensor.
*/
Tensor cache_write(const Tensor& tensor, const std::string& scope);
EXPORT Tensor cache_write(const Tensor& tensor, const std::string& scope);
/*!
* \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
* This will create a new stage that generated the new tensor with axis
......@@ -315,7 +315,7 @@ class Schedule : public NodeRef {
* \param axis The reduction axis in tensor's schedule to be factored.
* \return The created factored tensors.
*/
Array<Tensor> rfactor(const Tensor& tensor,
EXPORT Array<Tensor> rfactor(const Tensor& tensor,
const IterVar& axis);
/*!
* \brief Normalize the schedule.
......
......@@ -48,7 +48,7 @@ void AutoInlineElemWise(Schedule sch);
*
* \param sch The schedule to be inlined.
*/
void AutoInlineInjective(Schedule sch);
EXPORT void AutoInlineInjective(Schedule sch);
} // namespace schedule
} // namespace tvm
......
......@@ -261,9 +261,12 @@ def _init_api(namespace):
mod : str
The name of the module.
"""
module = sys.modules[namespace]
assert namespace.startswith("tvm.")
prefix = namespace[4:]
_init_api_prefix(namespace, prefix)
def _init_api_prefix(module_name, prefix):
module = sys.modules[module_name]
for name in list_global_func_names():
if prefix == "api":
......
......@@ -2,9 +2,9 @@
from __future__ import absolute_import
import sys
import os
import warnings
def find_lib_path(name=None, search_path=None):
def find_lib_path(name=None, search_path=None, optional=False):
"""Find dynamic library files.
Parameters
......@@ -56,7 +56,12 @@ def find_lib_path(name=None, search_path=None):
else:
dll_path.append(search_path)
if name is not None:
lib_dll_path = [os.path.join(p, name) for p in dll_path]
if isinstance(name, list):
lib_dll_path = []
for n in name:
lib_dll_path += [os.path.join(p, n) for p in dll_path]
else:
lib_dll_path = [os.path.join(p, name) for p in dll_path]
runtime_dll_path = []
else:
if sys.platform.startswith('win32'):
......@@ -81,9 +86,14 @@ def find_lib_path(name=None, search_path=None):
lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
if not lib_found:
raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' +
str('\n'.join(lib_dll_path + runtime_dll_path)))
message = ('Cannot find the files.\n' +
'List of candidates:\n' +
str('\n'.join(lib_dll_path + runtime_dll_path)))
if not optional:
raise RuntimeError(message)
else:
warnings.warn(message)
return None
if use_runtime:
sys.stderr.write("Loading runtime library %s... exec only\n" % lib_found[0])
......
......@@ -85,25 +85,25 @@ namespace target {
Target llvm() {
std::unordered_set<std::string> keys({ "llvm", "cpu" });
std::vector<std::string> options;
return Target("llvm", kDLCPU, 512, 1, keys, options);
return Target("llvm", kDLCPU, 512, 1, keys, options, {});
}
Target cuda() {
std::unordered_set<std::string> keys({ "cuda", "gpu" });
std::vector<std::string> options;
return Target("cuda", kDLGPU, 512, 32, keys, options);
return Target("cuda", kDLGPU, 512, 32, keys, options, {});
}
Target rocm() {
std::unordered_set<std::string> keys({ "rocm", "gpu" });
std::vector<std::string> options;
return Target("rocm", kDLROCM, 256, 1, keys, options);
return Target("rocm", kDLROCM, 256, 1, keys, options, {});
}
Target metal() {
std::unordered_set<std::string> keys({ "gpu" });
std::vector<std::string> options;
return Target("metal", kDLMetal, 256, 1, keys, options);
return Target("metal", kDLMetal, 256, 1, keys, options, {});
}
Target rasp() {
......@@ -114,7 +114,7 @@ Target rasp() {
"-mcpu=cortex-a53",
"-mattr=+neon"
});
return Target("llvm", kDLCPU, 512, 1, keys, options);
return Target("llvm", kDLCPU, 512, 1, keys, options, {});
}
Target mali() {
......@@ -129,7 +129,7 @@ Target mali() {
Target stackvm() {
std::unordered_set<std::string> keys({ "stackvm", "cpu" });
std::vector<std::string> options;
return Target("stackvm", kDLCPU, 512, 1, keys, options);
return Target("stackvm", kDLCPU, 512, 1, keys, options, {});
}
} // namespace target
......
......@@ -4,6 +4,7 @@
*/
#include <tvm/base.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
namespace tvm {
......
export PYTHONPATH=python:topi/python
python -m nose -v topi/tests/python_cpp || exit -1
python3 -m nose -v topi/tests/python_cpp || exit -1
......@@ -32,6 +32,7 @@ fi
if [ ${TASK} == "cpp_test" ] || [ ${TASK} == "all_test" ]; then
make -f dmlc-core/scripts/packages.mk gtest
./tests/scripts/task_cpp_unittest.sh || exit -1
./tests/scripts/task_cpp_topi.sh || exit -1
fi
if [ ${TASK} == "python_test" ] || [ ${TASK} == "all_test" ]; then
......
......@@ -33,7 +33,7 @@ inline tvm::Tensor broadcast_to(const tvm::Tensor& t,
<< output_shape << "\nvs\ninput: " << t;
auto bh = detail::BroadcastShape(output_shape, t->shape);
CHECK_EQ(output_shape.size(), bh.common_shape.size());
for (int i = 0; i < output_shape.size(); ++i) {
for (size_t i = 0; i < output_shape.size(); ++i) {
CHECK(tvm::ir::Equal(output_shape[i], bh.common_shape[i]));
}
auto l = [&](tvm::Array<tvm::Var> ovars) {
......@@ -147,6 +147,67 @@ inline tvm::Tensor broadcast_mod(const tvm::Tensor& A,
return detail::WithBroadcast(l, A, B, name, tag);
}
/*!
* \brief Creates an operation that performs pointwise maximum of 2 tensors
* and broadcasts them into a common compatible shape where necessary,
* according to numpy's rules
*
* \param A The first tensor
* \param B The second tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is a pointwise maximum with broadcast
*/
inline tvm::Tensor broadcast_maximum(const tvm::Tensor& A,
const tvm::Tensor& B,
std::string name = "tensor",
std::string tag = kBroadcast) {
auto l = [&](tvm::Expr a, tvm::Expr b) { return tvm::max(a, b); }; // NOLINT(*)
return detail::WithBroadcast(l, A, B, name, tag);
}
/*!
* \brief Creates an operation that performs pointwise minimum of 2 tensors
* and broadcasts them into a common compatible shape where necessary,
* according to numpy's rules
*
* \param A The first tensor
* \param B The second tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is a pointwise minimum with broadcast
*/
inline tvm::Tensor broadcast_minimum(const tvm::Tensor& A,
const tvm::Tensor& B,
std::string name = "tensor",
std::string tag = kBroadcast) {
auto l = [&](tvm::Expr a, tvm::Expr b) { return tvm::min(a, b); }; // NOLINT(*)
return detail::WithBroadcast(l, A, B, name, tag);
}
/*!
* \brief Creates an operation that raises one tensor to the power of another
* pointwise and broadcasts them into a common compatible shape where necessary,
* according to numpy's rules
*
* \param A The first tensor
* \param B The second tensor to compute pow(A, B)
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is a pointwise pow with
* broadcast
*/
inline tvm::Tensor broadcast_pow(const tvm::Tensor& A,
const tvm::Tensor& B,
std::string name = "tensor",
std::string tag = kBroadcast) {
auto l = [&](tvm::Expr a, tvm::Expr b) { return tvm::pow(a, b); };
return detail::WithBroadcast(l, A, B, name, tag);
}
} // namespace topi
#endif // TOPI_BROADCAST_H_
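
For reference, a minimal sketch of driving these broadcast builders from C++; the tensor names, shapes and dtype below are illustrative assumptions, not part of the change:

#include "tvm/tvm.h"
#include "topi/broadcast.h"

void BroadcastSketch() {
  // A [1, 4] tensor broadcast against a [3, 4] tensor, following numpy rules.
  tvm::Tensor A = tvm::placeholder({1, 4}, tvm::Float(32), "A");
  tvm::Tensor B = tvm::placeholder({3, 4}, tvm::Float(32), "B");
  tvm::Tensor C = topi::broadcast_pow(A, B);       // output shape is [3, 4]
  tvm::Tensor D = topi::broadcast_maximum(A, B);   // pointwise max with broadcast
}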
/*!
* Copyright (c) 2017 by Contributors
* \brief External function interface to cuBLAS libraries
* \file cublas.h
*/
#ifndef TOPI_CONTRIB_CUBLAS_H_
#define TOPI_CONTRIB_CUBLAS_H_
#include "tvm/tvm.h"
#include "topi/detail/extern.h"
namespace topi {
namespace contrib {
using namespace tvm;
using namespace topi::detail;
/*!
* \brief Create an op that multiplies lhs and rhs with cuBLAS
*
* \param lhs The left matrix operand
* \param rhs The right matrix operand
* \param transa Whether to transpose lhs
* \param transb Whether to transpose rhs
*
* \return The output tensor
*/
inline Tensor cublas_matmul(const Tensor& lhs,
const Tensor& rhs,
bool transa,
bool transb) {
auto n = transa ? lhs->shape[1] : lhs->shape[0];
auto m = transb ? rhs->shape[0] : rhs->shape[1];
return make_extern(
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
Expr("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
transa,
transb });
}, "C", "")[0];
}
} // namespace contrib
} // namespace topi
#endif // TOPI_CONTRIB_CUBLAS_H_
/*!
* Copyright (c) 2017 by Contributors
* \brief External function interface to rocBLAS libraries
* \file rocblas.h
*/
#ifndef TOPI_CONTRIB_ROCBLAS_H_
#define TOPI_CONTRIB_ROCBLAS_H_
#include "tvm/tvm.h"
#include "topi/detail/extern.h"
namespace topi {
namespace contrib {
using namespace tvm;
/*!
* \brief Create an op that multiplies lhs and rhs with rocBLAS
*
* \param lhs The left matrix operand
* \param rhs The right matrix operand
* \param transa Whether to transpose lhs
* \param transb Whether to transpose rhs
*
* \return The output tensor
*/
inline Tensor rocblas_matmul(const Tensor& lhs,
const Tensor& rhs,
bool transa,
bool transb) {
auto n = transa ? lhs->shape[1] : lhs->shape[0];
auto m = transb ? rhs->shape[0] : rhs->shape[1];
return make_extern(
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
Expr("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
transa,
transb });
}, "C", "")[0];
}
} // namespace contrib
} // namespace topi
#endif // TOPI_CONTRIB_ROCBLAS_H_
/*!
* Copyright (c) 2017 by Contributors
* \file cuda/dense.h
* \brief CUDA schedule for dense operation
*/
#ifndef TOPI_CUDA_DENSE_H_
#define TOPI_CUDA_DENSE_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/nn/dense.h"
#include "topi/contrib/cublas.h"
#include "topi/generic/extern.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Implementation of dense for CUDA backend
*
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_cuda(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor* bias) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias != nullptr) {
CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
}
auto batch = data->shape[0];
auto in_dim = data->shape[1];
auto out_dim = weight->shape[0];
if (target.libs.count("cublas") > 0) {
auto mm = topi::contrib::cublas_matmul(data, weight, false, true);
if (bias != nullptr) {
auto bias_val = *bias;
mm = tvm::compute({ batch, out_dim },
[&](Var i, Var j) {
return mm(i, j) + bias_val(j);
}, "tensor", kBroadcast);
}
return mm;
} else {
return topi::nn::dense(data, weight, bias);
}
}
/*!
* \brief Create a CUDA schedule for dense
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
if (target.target_name == "cuda" &&
target.libs.count("cublas") > 0) {
return topi::generic::schedule_extern(target, outs);
}
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto _schedule = [&](const Tensor& dense) {
auto num_thread = 64;
auto k = dense->op.as<ComputeOpNode>()->reduce_axis[0];
IterVar ko, kf;
s[dense].split(k, num_thread, &ko, &kf);
auto dense_f = s.rfactor(dense, kf)[0];
Tensor out;
if (contains(s->outputs, dense->op)) {
out = dense;
} else {
out = outs[0]->op.output(0);
s[dense].compute_at(s[out], s[out]->op.as<ComputeOpNode>()->axis[1]);
}
s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[0], tvm::thread_axis(Range(), "blockIdx.y"));
s[out].bind(s[out]->op.as<ComputeOpNode>()->axis[1], tvm::thread_axis(Range(), "blockIdx.x"));
auto tx = s[dense]->op.as<ComputeOpNode>()->reduce_axis[0];
auto thread_x = tvm::thread_axis(Range(), "threadIdx.x");
s[dense].bind(tx, thread_x);
s[dense_f].compute_at(s[dense], tx);
s[dense].set_store_predicate(static_cast<Expr>(thread_x) == 0);
s[out].set_store_predicate(static_cast<Expr>(thread_x) == 0);
};
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag == "dense") {
// If tag is dense, apply the dense schedule defined above
auto dense = op.output(0);
_schedule(dense);
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_DENSE_H_
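
The cuBLAS branch in dense_cuda and schedule_dense is selected purely through Target::libs, the field added to the Target struct earlier in this diff. A minimal sketch, with illustrative shapes and the lib inserted by hand for clarity:

#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/cuda/dense.h"

void DenseCudaSketch() {
  tvm::Target t = tvm::target::cuda();
  t.libs.insert("cublas");  // routes dense through tvm.contrib.cublas.matmul
  tvm::Tensor data = tvm::placeholder({128, 1024}, tvm::Float(32), "data");
  tvm::Tensor weight = tvm::placeholder({256, 1024}, tvm::Float(32), "weight");
  tvm::Tensor out = topi::cuda::dense_cuda(t, data, weight, nullptr);
  tvm::Schedule s = topi::cuda::schedule_dense(t, { out });  // extern schedule path
}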
/*!
* Copyright (c) 2017 by Contributors
* \file cuda/extern.h
* \brief CUDA schedule for extern followed by injective operations
*/
#ifndef TOPI_CUDA_EXTERN_H_
#define TOPI_CUDA_EXTERN_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Schedule a given operation representing one of the outputs of an
* external function which is followed by injective operations.
*
* \param target The target to generate a schedule for.
* \param op The operation representing the output followed by injective operations.
* \param sch The schedule to apply this scheduling to
*
* \return The schedule given by sch
*/
Schedule ScheduleOutputForExtern(Target target, Operation op, Schedule sch) {
auto x = op.output(0);
auto fused = Fuse(sch[x], sch[x]->op.as<ComputeOpNode>()->axis);
auto num_thread = target.max_num_threads;
IterVar bx, tx;
sch[x].split(fused, num_thread, &bx, &tx);
sch[x].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
sch[x].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
return sch;
}
/*!
* \brief Schedule an extern op followed by injective operations.
* For example, cudnn kernel + bias add + relu
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the op.
*/
Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
for (auto out : outs) {
if (out->op->derived_from<ExternOpNode>()) {
continue;
}
ScheduleOutputForExtern(target, out->op, s);
}
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_EXTERN_H_
/*!
* Copyright (c) 2017 by Contributors
* \file cuda/injective.h
* \brief CUDA schedule for injective operations
*/
#ifndef TOPI_CUDA_INJECTIVE_H_
#define TOPI_CUDA_INJECTIVE_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Schedule a given injective operation.
*
* \param target The target to generate a schedule for.
* \param op The operation representing the injective operation.
* \param s The schedule to apply this scheduling to
*/
void ScheduleInjectiveOp(const Target &target, Operation op, Schedule s) {
auto x = op.output(0);
auto fused = Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
auto num_thread = target.max_num_threads;
IterVar bx, tx;
s[x].split(fused, num_thread, &bx, &tx);
s[x].bind(bx, thread_axis(Range(), "blockIdx.x"));
s[x].bind(tx, thread_axis(Range(), "threadIdx.x"));
}
/*!
* \brief Create a CUDA schedule for the given output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
for (auto out : outs) {
ScheduleInjectiveOp(target, out->op, s);
}
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_INJECTIVE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file cuda/pooling.h
* \brief CUDA schedule for pooling operations
*/
#ifndef TOPI_CUDA_POOLING_H_
#define TOPI_CUDA_POOLING_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "topi/detail/array_utils.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Create a CUDA schedule for pool
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_pool(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto _schedule = [&](const Tensor& padded_input, const Tensor& pool) {
s[padded_input].compute_inline();
auto num_thread = target.max_num_threads;
Tensor out;
Tensor OL;
if (contains(s->outputs, pool->op)) {
out = pool;
OL = s.cache_write(pool, "local");
} else {
out = outs[0]->op.output(0);
s[pool].set_scope("local");
}
auto fused = Fuse(s[out], s[out]->op.as<ComputeOpNode>()->axis);
IterVar bx, tx;
s[out].split(fused, num_thread, &bx, &tx);
s[out].bind(bx, tvm::thread_axis(Range(), "blockIdx.x"));
s[out].bind(tx, tvm::thread_axis(Range(), "threadIdx.x"));
if (contains(s->outputs, pool->op)) {
s[OL].compute_at(s[out], tx);
} else {
s[pool].compute_at(s[out], tx);
}
};
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag.rfind("pool", 0) == 0) {
// If tag starts with pool
auto padded_input = op->InputTensors()[0];
auto pool = op.output(0);
_schedule(padded_input, pool);
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
return s;
}
/*!
* \brief Create a CUDA schedule for global_pool
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_global_pool(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto _schedule = [&](const Tensor& pool) {
auto num_thread = 8;
auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
auto block_y = tvm::thread_axis(Range(), "blockIdx.y");
auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
auto thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
Tensor out;
Tensor OL;
if (contains(s->outputs, pool->op)) {
out = pool;
OL = s.cache_write(pool, "local");
} else {
out = outs[0]->op.output(0);
s[pool].set_scope("local");
}
auto i = s[out]->op.as<ComputeOpNode>()->axis[0];
auto c = s[out]->op.as<ComputeOpNode>()->axis[1];
IterVar by, ty;
s[out].split(i, num_thread, &by, &ty);
IterVar bx, tx;
s[out].split(c, num_thread, &bx, &tx);
s[out].reorder({ by, bx, ty, tx });
s[out].bind(ty, thread_y);
s[out].bind(tx, thread_x);
s[out].bind(by, block_y);
s[out].bind(bx, block_x);
if (contains(s->outputs, pool->op)) {
s[OL].compute_at(s[out], tx);
} else {
s[pool].compute_at(s[out], tx);
}
};
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag.rfind("global_pool", 0) == 0) {
// If tag starts with global_pool
auto pool = op.output(0);
_schedule(pool);
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_POOLING_H_
/*!
* Copyright (c) 2017 by Contributors
* \file cuda/reduction.h
* \brief CUDA schedule for reduction operations
*/
#ifndef TOPI_CUDA_REDUCTION_H_
#define TOPI_CUDA_REDUCTION_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Schedule a given reduce operation.
*
* \param target The target to generate a schedule for.
* \param op The operation representing the injective operation.
* \param sch The schedule to apply this scheduling to
* \param is_idx_reduce Pass true to schedule a reduce op that returns
* an index, such as argmax or argmin.
*
* \return The schedule given by sch
*/
Schedule ScheduleReduce(const Target& target,
Operation op,
Schedule sch,
bool is_idx_reduce = false) {
Tensor data_out;
Tensor data_in;
if (!is_idx_reduce) {
data_in = op->InputTensors()[0];
data_out = op.output(0);
} else {
data_out = op->InputTensors()[0];
}
auto out_stage = sch[data_out];
CHECK_GT(out_stage->op.as<ComputeOpNode>()->reduce_axis.size(), 0) <<
"reduce_axis must be greater than zero";
bool all_reduce;
int num_thread;
IterVar block_x, thread_x, thread_y;
if (out_stage->op.as<ComputeOpNode>()->axis.size() > 0) {
all_reduce = false;
num_thread = 32;
if (target.target_name == "opencl") {
// Without this, CL_INVALID_WORK_GROUP_SIZE occurs with python tests.
// Don't know why.
num_thread = 16;
}
block_x = tvm::thread_axis(Range(), "blockIdx.x");
thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
thread_y = tvm::thread_axis(Range(0, num_thread), "threadIdx.y");
} else {
all_reduce = true;
num_thread = target.max_num_threads;
thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
}
auto fused_reduce = Fuse(out_stage, out_stage->op.as<ComputeOpNode>()->reduce_axis);
IterVar ko, ki;
out_stage.split(fused_reduce, num_thread, &ko, &ki);
auto data_out_rf = sch.rfactor(data_out, ki)[0];
auto tx = out_stage->op.as<ComputeOpNode>()->reduce_axis[0];
out_stage.bind(tx, thread_x);
sch[data_out_rf].compute_at(out_stage, tx);
Tensor real_output;
Tensor temp_idx_input, temp_val_input;
if (is_idx_reduce) {
real_output = op.output(0);
temp_idx_input = data_out->op.output(0);
temp_val_input = data_out->op.output(1);
} else {
real_output = data_out;
}
auto stage_real = sch[real_output];
if (!all_reduce) {
// Fuse and split the axis
auto fused_outer = Fuse(stage_real, stage_real->op.as<ComputeOpNode>()->axis);
IterVar bx, outer_in;
stage_real.split(fused_outer, num_thread, &bx, &outer_in);
// Bind the axes to threads and blocks
stage_real.bind(outer_in, thread_y);
stage_real.bind(bx, block_x);
if (is_idx_reduce) {
sch[temp_idx_input].compute_at(stage_real, outer_in);
sch[temp_val_input].compute_at(stage_real, outer_in);
}
} else {
if (is_idx_reduce) {
sch[temp_idx_input].compute_at(stage_real,
stage_real->op.as<ComputeOpNode>()->axis[0]);
sch[temp_val_input].compute_at(stage_real,
stage_real->op.as<ComputeOpNode>()->axis[0]);
}
}
stage_real.set_store_predicate(static_cast<Expr>(thread_x) == 0);
return sch;
}
/*!
* \brief Recursively traverse operator inputs, setting injective inputs
* to be computed inline.
*
* \param s The schedule we are building
* \param op The current op in the traversal
*/
void TraverseBeforeReduce(Schedule s, Operation op) {
if (op->derived_from<PlaceholderOpNode>()) {
return;
} else if (is_injective(op->tag)) {
s[op].compute_inline();
for (auto tensor : op->InputTensors()) {
TraverseBeforeReduce(s, tensor->op);
}
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
}
/*!
* \brief Schedule a reduce op, then invoke TraverseBeforeReduce on each
* of the op's inputs.
*
* \param target The target to generate a schedule for.
* \param s The schedule we are building
* \param op The reduce op
*/
void TraverseAfterReduce(const Target& target, Schedule s, Operation op) {
if (is_broadcast(op->tag)) {
LOG(ERROR) << "Elementwise op after reduce is not yet supported";
} else if (op->tag == kCommReduce) {
ScheduleReduce(target, op, s, false);
for (auto tensor : op->InputTensors()) {
TraverseBeforeReduce(s, tensor->op);
}
} else if (op->tag == kCommReduceIdx) {
ScheduleReduce(target, op, s, true);
for (auto tensor : op->InputTensors()[0]->op->InputTensors()) {
TraverseBeforeReduce(s, tensor->op);
}
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
}
/*!
* \brief Create a CUDA schedule for a reduce operation.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_reduce(const Target& target, Array<Tensor> outs) {
CHECK_EQ(outs.size(), 1) << "outs must have size 1";
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
TraverseAfterReduce(target, s, outs[0]->op);
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_REDUCTION_H_
/*!
* Copyright (c) 2017 by Contributors
* \file cuda/softmax.h
* \brief CUDA schedule for softmax operations
*/
#ifndef TOPI_CUDA_SOFTMAX_H_
#define TOPI_CUDA_SOFTMAX_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace cuda {
/*!
* \brief Create a CUDA schedule for the given softmax output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto softmax = outs[0];
auto max_elem = softmax->op->InputTensors()[1];
auto expsum = softmax->op->InputTensors()[2];
int num_thread = 64;
auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
IterVar ko, ki;
s[expsum].split(k, num_thread, &ko, &ki);
auto EF = s.rfactor(expsum, ki)[0];
s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->axis[0], block_x);
s[expsum].bind(s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0], thread_x);
s[EF].compute_at(s[expsum], s[expsum]->op.as<ComputeOpNode>()->reduce_axis[0]);
s[expsum].set_store_predicate(thread_x->var == 0);
IterVar tx, xi;
s[softmax].split_by_nparts(softmax->op.as<ComputeOpNode>()->axis[1], num_thread, &tx, &xi);
s[softmax].bind(tx, thread_x);
return s;
}
} // namespace cuda
} // namespace topi
#endif // TOPI_CUDA_SOFTMAX_H_
/*!
* Copyright (c) 2017 by Contributors
* \file array_utils.h
* \brief Utility functions for handling arrays
*/
#ifndef TOPI_DETAIL_ARRAY_UTILS_H_
#define TOPI_DETAIL_ARRAY_UTILS_H_
#include "tvm/tvm.h"
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Search an array for a specific item
*
* \param array The array to search
* \param item The item to search for
*
* \return True iff the given array contains the given item.
*/
template<typename T>
bool contains(Array<T> array, T item) {
for (auto& i : array) {
if (i == item) {
return true;
}
}
return false;
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_ARRAY_UTILS_H_
......@@ -69,10 +69,10 @@ inline tvm::Array<tvm::Expr> InputIndexFromBroadcast(
tvm::Array<tvm::Expr> ivars;
CHECK_EQ(ovars.size(), all_vars.size());
// N^2, could use a map but NBD..
int expected_dims = T->shape.size();
for (int i = 0; i < ovars.size(); ++i) {
size_t expected_dims = T->shape.size();
for (size_t i = 0; i < ovars.size(); ++i) {
bool found = false;
for (int j = 0; j < my_vars.size(); ++j) {
for (size_t j = 0; j < my_vars.size(); ++j) {
if (all_vars[i].same_as(my_vars[j])) {
ivars.push_back(ovars[i]);
found = true;
......
/*!
* Copyright (c) 2017 by Contributors
* \file constant_utils.h
* \brief Utility functions for handling constants in TVM expressions
*/
#ifndef TOPI_DETAIL_CONSTANT_UTILS_H_
#define TOPI_DETAIL_CONSTANT_UTILS_H_
#include <string>
#include <vector>
#include "tvm/tvm.h"
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Test whether the given Expr is a constant integer
*
* \param expr the Expr to query
*
* \return true if the given expr is a constant int or uint, false otherwise.
*/
bool IsConstInt(Expr expr) {
return
expr->derived_from<tvm::ir::IntImm>() ||
expr->derived_from<tvm::ir::UIntImm>();
}
/*!
* \brief Get the value of the given constant integer expression. An error
* is logged if the given expression is not a constant integer.
*
* \param expr The expression to get the value of
*
* \return The integer value.
*/
int64_t GetConstInt(Expr expr) {
if (expr->derived_from<tvm::ir::IntImm>()) {
return expr.as<tvm::ir::IntImm>()->value;
}
if (expr->derived_from<tvm::ir::UIntImm>()) {
return expr.as<tvm::ir::UIntImm>()->value;
}
LOG(ERROR) << "expr must be a constant integer";
return -1;
}
/*!
* \brief Get the value of all the constant integer expressions in the given array
*
* \param exprs The array of expressions to get the values of
* \param var_name The name to be used when logging an error in the event that any
* of the expressions are not constant integers.
*
* \return A vector of the integer values
*/
std::vector<int> GetConstIntValues(Array<Expr> exprs, const std::string& var_name) {
std::vector<int> result;
for (auto expr : exprs) {
CHECK(IsConstInt(expr)) << "All elements of " << var_name << " must be constant integers";
result.push_back(GetConstInt(expr));
}
return result;
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_CONSTANT_UTILS_H_
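
A small sketch of how these helpers are meant to be called (the kernel-size values are illustrative):

#include <vector>
#include "topi/detail/constant_utils.h"

void ConstantUtilsSketch() {
  tvm::Array<tvm::Expr> kernel = { 3, 3 };
  // Both entries are IntImm, so this yields {3, 3}; a non-constant Expr such
  // as a Var would trip the CHECK inside GetConstIntValues.
  std::vector<int> k = topi::detail::GetConstIntValues(kernel, "kernel_size");
}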
/*!
* Copyright (c) 2017 by Contributors
* \file detail/extern.h
* \brief Helpers for using external functions
*/
#ifndef TOPI_DETAIL_EXTERN_H_
#define TOPI_DETAIL_EXTERN_H_
#include <vector>
#include <string>
#include "tvm/tvm.h"
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Construct a buffer to pass to an external function
*
* \param shape The shape of the buffer
* \param dtype The type of the buffer elements
* \param name The name of the buffer
*
* \return The Buffer object
*/
Buffer DeclExternBuffer(Array<Expr> shape,
Type dtype,
std::string name) {
auto data = var(name, Handle());
auto elem_offset = Expr();
return BufferNode::make(data, dtype, shape, Array<Expr>(), elem_offset, name, "",
-1, 0);
}
/*!
* \brief A function which constructs an Expr representing the invocation of an external
* function. The function expects two arguments: an array of Buffers holding the input
* tensor values, and a pre-allocated array of Buffers to be filled with the outputs.
*/
using FExtern = std::function<Expr(Array<Buffer>, Array<Buffer>)>;
/*!
* \brief Create tensors representing the result of invoking an external function.
* This function will create the necessary buffers to hold input and output tensor values.
*
* \param out_shapes An array where each element is the shape of the corresponding output tensor.
* \param out_types An array where each element is the dtype of the corresponding output tensor.
* \param inputs An array of input Tensors
* \param fextern A function that constructs an Expr representing the invocation of
* the external function given the input and output buffers.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return An array of Tensors representing the outputs of the function invocation. There will
* be one output Tensor for each element of out_shapes, with dtype equal to the corresponding
* element of out_types.
*/
Array<Tensor> make_extern(const Array< Array<Expr> >& out_shapes,
const std::vector<Type>& out_types,
const Array<Tensor>& inputs,
FExtern fextern,
std::string name,
std::string tag) {
CHECK_EQ(out_shapes.size(), out_types.size())
<< "make_extern: out_shapes and out_types must have equal size";
Array<Buffer> input_placeholders;
for (auto t : inputs) {
input_placeholders.push_back(DeclExternBuffer(t->shape, t->dtype, t->op->name));
}
Array<Buffer> output_placeholders;
for (size_t i = 0; i < out_shapes.size(); ++i) {
output_placeholders.push_back(DeclExternBuffer(out_shapes[i], out_types[i], name));
}
auto body = fextern(input_placeholders, output_placeholders);
auto body_stmt = tvm::ir::Evaluate::make(body);
auto op = ExternOpNode::make(
name, tag, inputs, input_placeholders, output_placeholders, body_stmt);
Array<Tensor> outputs;
for (size_t i = 0; i < output_placeholders.size(); ++i) {
outputs.push_back(op.output(i));
}
return outputs;
}
/*!
* \brief This function is used to create a DLTensor structure on the stack to
* be able to pass a symbolic buffer as arguments to TVM PackedFunc
*
* \param buf The buffer to pack
*
* \return An expression representing the pack operation
*/
Expr pack_buffer(Buffer buf) {
CHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element";
auto shape = tvm::ir::Call::make(Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
buf->shape, tvm::ir::Call::CallType::Intrinsic);
Expr strides;
if (buf->strides.size() > 0) {
strides = tvm::ir::Call::make(Handle(), tvm::ir::intrinsic::tvm_stack_make_shape,
buf->strides, tvm::ir::Call::CallType::Intrinsic);
} else {
strides = 0;
}
Array<Expr> pack_args{
buf->data,
shape,
strides,
make_const(Int(32), buf->shape.size()),
make_const(buf->dtype, 0),
buf->elem_offset
};
return tvm::ir::Call::make(Handle(), tvm::ir::intrinsic::tvm_stack_make_array,
pack_args, tvm::ir::Call::CallType::Intrinsic);
}
/*!
* \brief Construct an Expr representing the invocation of a PackedFunc
*
* \param args An array containing the registered name of the PackedFunc followed
* by the arguments to pass to the PackedFunc when called. The first element of the
* array must be a constant string expression.
*
* \return An expression representing the invocation
*/
Expr call_packed(Array<Expr> args) {
return tvm::ir::Call::make(Int(32), tvm::ir::intrinsic::tvm_call_packed,
args, tvm::ir::Call::CallType::Intrinsic);
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_EXTERN_H_
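
The cublas_matmul and rocblas_matmul wrappers earlier in this diff are direct consumers of make_extern; condensed into a sketch (non-transposed operands, output dtype taken from the left operand), the pattern is:

#include "tvm/tvm.h"
#include "topi/detail/extern.h"

tvm::Tensor ExternMatmulSketch(const tvm::Tensor& lhs, const tvm::Tensor& rhs) {
  using namespace topi::detail;
  // One output of shape [N, M], filled in by the packed cuBLAS call.
  return make_extern(
    { { lhs->shape[0], rhs->shape[1] } }, { lhs->dtype }, { lhs, rhs },
    [&](tvm::Array<tvm::Buffer> ins, tvm::Array<tvm::Buffer> outs) {
      return call_packed({ tvm::Expr("tvm.contrib.cublas.matmul"),
                           pack_buffer(ins[0]), pack_buffer(ins[1]),
                           pack_buffer(outs[0]), false, false });
    }, "C", "")[0];
}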
/*!
* Copyright (c) 2017 by Contributors
* \file fuse.h
* \brief Fuse operation
*/
#ifndef TOPI_DETAIL_FUSE_H_
#define TOPI_DETAIL_FUSE_H_
#include "tvm/tvm.h"
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Fuse all of the given args
*
* \param stage The stage in which to apply the fuse
* \param args The iteration variables to be fused
*
* \return The fused iteration variable
*/
IterVar Fuse(Stage stage, const Array<IterVar>& args) {
CHECK_GE(args.size(), 1) << "Fuse requires at least 1 arg";
auto fused = args[0];
for (size_t i = 1; i < args.size(); ++i) {
IterVar out;
stage.fuse(fused, args[i], &out);
fused = out;
}
return fused;
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_FUSE_H_
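
This helper is what the injective, extern and pooling schedules above rely on: for example, ScheduleInjectiveOp calls Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis) to collapse every output axis into a single IterVar before splitting it across blockIdx.x and threadIdx.x.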
/*!
* Copyright (c) 2017 by Contributors
* \file pad_utils.h
* \brief Padding helpers
*/
#ifndef TOPI_DETAIL_PAD_UTILS_H_
#define TOPI_DETAIL_PAD_UTILS_H_
#include <vector>
#include "tvm/tvm.h"
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Get padding size for each side given padding height and width
*
* \param pad_h The amount to pad each of the top and bottom sides
* \param pad_w The amount to pad each of the left and right sides
*
* \return An array of 4 elements, representing padding sizes for
* each individual side. The array is in the order { top, left, bottom, right }
*/
Array<Expr> GetPadTuple(Expr pad_h, Expr pad_w) {
pad_h *= 2;
pad_w *= 2;
auto pad_top = (pad_h + 1) / 2;
auto pad_left = (pad_w + 1) / 2;
return { pad_top, pad_left, pad_h - pad_top, pad_w - pad_left };
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_PAD_UTILS_H_
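
Worked through with constants: GetPadTuple(1, 2) doubles the arguments to 2 and 4, computes pad_top = (2 + 1) / 2 = 1 and pad_left = (4 + 1) / 2 = 2 (integer division), and returns { 1, 2, 1, 2 }, i.e. one row of padding on top and bottom and two columns on the left and right.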
/*!
* Copyright (c) 2017 by Contributors
* \file ravel_unravel.h
* \brief Index ravel and unravel operations
*/
#ifndef TOPI_DETAIL_RAVEL_UNRAVEL_H_
#define TOPI_DETAIL_RAVEL_UNRAVEL_H_
#include <vector>
#include "tvm/tvm.h"
namespace topi {
namespace detail {
using namespace tvm;
/*!
* \brief Flatten the indices to 1D
*
* \param indices The input coordinates
* \param shape Shape of the tensor
*
* \return The index after flattening
*/
inline Expr RavelIndex(Array<Var> indices, Array<Expr> shape) {
CHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
CHECK_GT(indices.size(), 0) << "indices must not be empty";
Expr idx;
for (size_t i = 0; i < indices.size(); ++i) {
if (i == 0) {
idx = indices[i];
} else {
idx = idx * shape[i] + indices[i];
}
}
return idx;
}
/*!
* \brief Convert flattened index to coordinate array
*
* \param idx The 1D index
* \param shape Shape of the tensor
*
* \return The coordinate corresponding to the 1D index
*/
inline Array<Expr> UnavelIndex(Expr idx, Array<Expr> shape) {
std::vector<Expr> indices;
for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
indices.push_back(idx % shape[i]);
idx = idx / shape[i];
}
std::reverse(indices.begin(), indices.end());
return indices;
}
} // namespace detail
} // namespace topi
#endif // TOPI_DETAIL_RAVEL_UNRAVEL_H_
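
As a concrete round trip: with shape { 4, 5 }, RavelIndex({ 2, 3 }, { 4, 5 }) evaluates to 2 * 5 + 3 = 13, and unravelling 13 against { 4, 5 } peels off 13 % 5 = 3 and then 13 / 5 = 2, recovering { 2, 3 }.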
......@@ -28,6 +28,144 @@ TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
/*!
* \brief Creates an operation that returns identity of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the identity operation
*/
inline Tensor identity(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return x(i);
}, name, tag);
}
/*!
* \brief Creates an operation that returns the negation of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the negation operation
*/
inline Tensor negative(const Tensor& x,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return -x(i);
}, name, tag);
}
/*!
* \brief Creates an operation that raises each element of tensor x to power y
*
* \param x The input tensor
* \param y The exponent
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the pow operation
*/
inline Tensor pow(const Tensor& x,
const Expr& y,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
return tvm::pow(x(i), y);
}, name, tag);
}
/*!
* \brief Creates an operation that performs pointwise left shift by n bits
*
* \param x The input tensor
* \param n The number of bits to shift by
*
* \return A Tensor whose op member is the left shift operation
*/
inline Tensor operator<<(const Tensor& x,
const Expr& n) {
return compute(x->shape, [&](const Array<Var>& i) {
return x(i) << n;
}, "tensor", kElementWise);
}
/*!
* \brief Creates an operation that performs pointwise right shift by n bits
*
* \param x The input tensor
* \param n The number of bits to shift by
*
* \return A Tensor whose op member is the right shift operation
*/
inline Tensor operator>>(const Tensor& x,
const Expr& n) {
return compute(x->shape, [&](const Array<Var>& i) {
return x(i) >> n;
}, "tensor", kElementWise);
}
/*!
* \brief Creates an operation that clips each element of a tensor to
* the interval [a_min, a_max]
*
* \param x The input tensor
* \param a_min The inclusive lower bound of the interval
* \param a_max The inclusive upper bound of the interval
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the clip operation
*/
inline Tensor clip(const Tensor& x,
const Expr& a_min,
const Expr& a_max,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
auto min_val = tvm::cast(x->dtype, a_min);
auto max_val = tvm::cast(x->dtype, a_max);
return tvm::max(tvm::min(x(i), max_val), min_val); // NOLINT(*)
}, name, tag);
}
/*!
* \brief Cast each element of x to the given type. If expr is
* scalar and type is a corresponding vector type, a
* Broadcast is generated, otherwise a Cast is generated.
*
* \param x The input tensor
* \param type The type to cast to
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the cast operation
*/
inline Tensor cast(const Tensor& x,
Type type,
std::string name = "tensor",
std::string tag = kElementWise) {
return compute(x->shape, [&](const Array<Var>& i) {
auto expr = x(i);
if (expr.type().code() == type.code() && expr.type().bits() == type.bits()) {
if (expr.type().lanes() == type.lanes()) {
return expr;
} else if (expr.type().lanes() == 1 && type.lanes() > 1) {
return tvm::ir::Broadcast::make(expr, type.lanes());
}
}
return tvm::cast(type, x(i));
}, name, tag);
}
} // namespace topi
#endif // TOPI_ELEMWISE_H_
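A note on clip() above: the bounds are cast to the input dtype before comparing, so the output keeps the input's dtype (the dtype issue called out in the commit message). A small NumPy sketch of the same behaviour, assuming constant bounds:
import numpy as np
def clip_ref(x, a_min, a_max):
    # cast the bounds to x's dtype first, mirroring the tvm::cast(x->dtype, ...) calls in clip()
    lo = np.asarray(a_min, dtype=x.dtype)
    hi = np.asarray(a_max, dtype=x.dtype)
    return np.maximum(np.minimum(x, hi), lo)
x = np.array([-300, 5, 300], dtype=np.int16)
assert clip_ref(x, -127, 127).dtype == np.int16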
/*!
* Copyright (c) 2017 by Contributors
* \file generic/default.h
* \brief Generic default schedule
*/
#ifndef TOPI_GENERIC_DEFAULT_H_
#define TOPI_GENERIC_DEFAULT_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace generic {
/*!
* \brief Create a generic default schedule for the given output tensors.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
* \param auto_inline Whether to apply the auto inline step.
*
* \return A schedule for the given ops.
*/
Schedule default_schedule(const Target& target, Array<Tensor> outs, bool auto_inline) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
if (auto_inline) {
auto x = outs[0];
tvm::schedule::AutoInlineInjective(s);
auto axis = s[x]->op.as<ComputeOpNode>()->axis;
if (axis.size() > 0) {
Fuse(s[x], axis);
}
}
return s;
}
} // namespace generic
} // namespace topi
#endif // TOPI_GENERIC_DEFAULT_H_
/*!
* Copyright (c) 2017 by Contributors
* \file generic/extern.h
* \brief Schedule for extern followed by injective ops
*/
#ifndef TOPI_GENERIC_EXTERN_H_
#define TOPI_GENERIC_EXTERN_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace generic {
/*!
* \brief Schedule an extern op followed by injective operations
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the op.
*/
Schedule schedule_extern(const Target& target, Array<Tensor> outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
return s;
}
} // namespace generic
} // namespace topi
#endif // TOPI_GENERIC_EXTERN_H_
/*!
* Copyright (c) 2017 by Contributors
* \file generic/injective.h
* \brief Generic schedule for injective operations
*/
#ifndef TOPI_GENERIC_INJECTIVE_H_
#define TOPI_GENERIC_INJECTIVE_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace generic {
/*!
* \brief Create a generic schedule for the given injective ops.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
auto x = outs[0];
Fuse(s[x], s[x]->op.as<ComputeOpNode>()->axis);
return s;
}
} // namespace generic
} // namespace topi
#endif // TOPI_GENERIC_INJECTIVE_H_
......@@ -21,7 +21,7 @@ template <typename T>
tvm::Expr Map(const tvm::Array<tvm::Expr>& exprs, T op) {
CHECK_GE(exprs.size(), 1);
tvm::Expr res = exprs[0];
for (int i = 1; i < exprs.size(); ++i) {
for (size_t i = 1; i < exprs.size(); ++i) {
res = op(res, exprs[i]);
}
return res;
......@@ -52,6 +52,34 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
}
/*!
* \brief Creates an operation that performs a leaky rectified linear unit
*
* \param t The input tensor
* \param threshold The relu threshold (default 0)
* \param alpha The slope for the small gradient when t < threshold
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the relu operation
*/
template <typename T>
inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
T threshold = static_cast<T>(0),
T alpha = static_cast<T>(0.1),
std::string name = "tensor",
std::string tag = kElementWise) {
return tvm::compute(
t->shape,
[&](const tvm::Array<tvm::Var>& i) {
auto value = t(i);
auto calpha = tvm::make_const(value.type(), alpha);
return tvm::select(value > 0, value, value * calpha);
},
name,
tag);
}
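The leaky_relu computation above is select(x > 0, x, alpha * x); the NumPy reference used by the relu tests later in this diff is equivalent to the following sketch:
import numpy as np
def leaky_relu_ref(x, alpha=0.1):
    # keep positive values, scale negative values by alpha
    return np.where(x > 0, x, alpha * x)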
/*!
* \brief Creates an operation that performs padding
*
* \param t The input tensor
......@@ -59,10 +87,11 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
* respective iterator
* \param pad_after An Array of Expr describing the padding after the
* respective iterator
* \param pad_value The value to fill padding elements with
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the relu operation
* \return A Tensor whose op member is the padding operation
*
* \note
* The pad_after Array must either be empty or have the same length as
......@@ -86,17 +115,18 @@ inline tvm::Tensor relu(const tvm::Tensor& t,
inline tvm::Tensor pad(const tvm::Tensor& t,
const tvm::Array<tvm::Expr>& pad_before,
tvm::Array<tvm::Expr> pad_after = tvm::Array<tvm::Expr>(),
Expr pad_value = Expr(),
std::string name = "tensor",
std::string tag = kElementWise) {
if (pad_after.size() < pad_before.size()) {
for (int i = pad_after.size(); i < pad_before.size(); ++i) {
for (size_t i = pad_after.size(); i < pad_before.size(); ++i) {
pad_after.push_back(pad_before[i]);
}
}
CHECK_GE(pad_before.size(), 1);
CHECK_EQ(pad_before.size(), pad_after.size());
tvm::Array<tvm::Expr> output_shape;
for (int i = 0; i < t->shape.size(); ++i) {
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) {
output_shape.push_back(t->shape[i]);
} else {
......@@ -104,10 +134,15 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
tvm::ir::Simplify(t->shape[i] + pad_before[i] + pad_after[i]));
}
}
if (!pad_value.defined()) {
pad_value = tvm::make_const(t->dtype, 0);
}
auto l = [&](tvm::Array<tvm::Var> ovars) {
tvm::Array<tvm::Expr> indices;
tvm::Array<tvm::Expr> sel;
for (int i = 0; i < t->shape.size(); ++i) {
for (size_t i = 0; i < t->shape.size(); ++i) {
if (i >= pad_before.size()) {
indices.push_back(ovars[i]);
continue;
......@@ -122,7 +157,10 @@ inline tvm::Tensor pad(const tvm::Tensor& t,
sel.push_back(tvm::ir::Simplify(ovars[i] < pad_before[i] + t->shape[i]));
}
}
return tvm::select(detail::Map(sel, tvm::ir::And::make), t(indices), 0);
if (sel.size() != 0) {
return tvm::select(detail::Map(sel, tvm::ir::And::make), t(indices), pad_value);
}
return t(indices);
};
return tvm::compute(output_shape, l, name, tag);
}
......
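As a reference for the pad() semantics above (pad_after defaulting to pad_before, a configurable pad_value, and trailing dimensions left untouched), a minimal NumPy sketch with a hypothetical helper name:
import numpy as np
def pad_ref(x, pad_before, pad_after=None, pad_value=0):
    pad_after = pad_before if pad_after is None else pad_after
    widths = list(zip(pad_before, pad_after))
    widths += [(0, 0)] * (x.ndim - len(widths))  # dims beyond pad_before are not padded
    return np.pad(x, widths, mode='constant', constant_values=pad_value)
assert pad_ref(np.ones((2, 2)), [1, 1]).shape == (4, 4)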
/*!
* Copyright (c) 2017 by Contributors
* \brief Batch normalization op constructions
* \file nn/batch_norm.h
*/
#ifndef TOPI_NN_BATCH_NORM_H_
#define TOPI_NN_BATCH_NORM_H_
#include <string>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Batch normalization inference operator with NCHW layout
*
* \param x The input tensor. 4-D with shape [batch, channel, height, width]
* \param gamma 1-D with shape [channel]
* \param beta 1-D with shape [channel]
* \param moving_mean 1-D with shape [channel]
* \param moving_var 1-D with shape [channel]
* \param eps Epsilon to prevent div by 0
* \param fix_gamma Fix gamma while training
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the batch normalization operation
*/
inline Tensor batch_norm_inference(const Tensor& x,
const Tensor& gamma,
const Tensor& beta,
const Tensor& moving_mean,
const Tensor& moving_var,
float eps,
bool fix_gamma,
std::string name = "tensor",
std::string tag = kBroadcast) {
CHECK_EQ(x->shape.size(), 4) << "Batch norm requires 4-D input";
Tensor out;
if (fix_gamma) {
out = tvm::compute(
x->shape,
[&](const Array<Var>& indices) {
auto c = Array<Var>({ indices[1] });
return (x(indices) - moving_mean(c)) / tvm::sqrt(moving_var(c) + eps) + beta(c);
}, name, tag);
} else {
out = tvm::compute(
x->shape,
[&](const Array<Var>& indices) {
auto c = Array<Var>({ indices[1] });
return (x(indices) - moving_mean(c)) / tvm::sqrt(moving_var(c) + eps) * gamma(c) + beta(c);
}, name, tag);
}
return out;
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_BATCH_NORM_H_
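The two branches above implement the standard inference-time batch normalization, with gamma skipped when fix_gamma is set. A NumPy sketch of the same formula for NCHW input:
import numpy as np
def batch_norm_inference_ref(x, gamma, beta, mean, var, eps, fix_gamma=False):
    shape = (1, -1, 1, 1)  # broadcast per-channel parameters over N, H, W
    out = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    if not fix_gamma:
        out = out * gamma.reshape(shape)
    return out + beta.reshape(shape)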
/*!
* Copyright (c) 2017 by Contributors
* \brief Binary op constructions
* \file nn/bnn.h
*/
#ifndef TOPI_NN_BNN_H_
#define TOPI_NN_BNN_H_
#include <string>
#include "tvm/tvm.h"
#include "tvm/ir_pass.h"
#include "topi/tags.h"
#include "topi/detail/constant_utils.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Binarization and bit-packing along a certain axis.
*
* \param data N-D tensor, can be any layout
* \param axis The axis along which to do binarization and bit-packing. This axis
* must have a size equal to an integer multiple of 32.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return Output tensor with dtype uint32
*/
inline tvm::Tensor binarize_pack(const tvm::Tensor& data,
int axis,
std::string name = "PackedInput",
std::string tag = "binarize_pack") {
auto ishape = data->shape;
CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0)
<< "binarize_pack: axis size must be a multiple of 32";
auto n = ishape.size();
Array<Expr> oshape;
for (size_t i = 0; i < n; ++i) {
oshape.push_back(i == static_cast<size_t>(axis) ?
tvm::ir::Simplify(ishape[i] / 32) :
ishape[i]);
}
return tvm::compute(
oshape,
[&](const Array<Var>& indices) {
Array<Expr> start_idx;
for (size_t i = 0; i < n; ++i) {
start_idx.push_back(i == static_cast<size_t>(axis) ?
indices[i] * 32 :
static_cast<Expr>(indices[i]));
}
auto packed = make_const(UInt(32), 0);
for (size_t j = 0; j < 32; ++j) {
Array<Expr> idx;
for (size_t i = 0; i < n; ++i) {
idx.push_back(i == static_cast<size_t>(axis) ?
start_idx[i] + static_cast<int>(j) :
start_idx[i]);
}
auto sign = tvm::cast(UInt(32), data(idx) >= 0);
packed = (packed | sign);
if (j == 31) {
return packed;
}
packed = packed << 1;
}
return packed; // never reached, but suppress compiler warning
}, name, tag);
}
/*!
* \brief Binary matrix multiplication using xor and bit-count
*
* \param data Tensor with shape [batch, in_dim], dtype is uint32
* \param weight Tensor with shape [out_dim, in_dim], dtype is uint32
*
* \return Tensor with shape [batch, out_dim], dtype is float32
*/
inline tvm::Tensor binary_dense(const tvm::Tensor& data,
const tvm::Tensor& weight) {
CHECK_EQ(data->shape.size(), 2) << "binary_dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "binary_dense requires 2-D weight";
CHECK_EQ(data->dtype, UInt(32)) << "binary_dense requires uint32 data";
CHECK_EQ(weight->dtype, UInt(32)) << "binary_dense requires uint32 weight";
auto batch = data->shape[0];
auto in_dim = data->shape[1];
auto out_dim = weight->shape[0];
auto k = tvm::reduce_axis(Range(0, in_dim), "k");
auto matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return tvm::sum(popcount(data(i, k) ^ weight(j, k)), { k });
}, "tensor", "binary_dense");
return tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return 32 * in_dim - 2.0f * matmul(i, j);
}, "tensor", kElementWise);
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_BNN_H_
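The arithmetic behind binary_dense above: for +/-1 vectors packed as sign bits, each matching bit contributes +1 and each differing bit contributes -1, so the dot product equals total_bits - 2 * popcount(xor). That is exactly what 32 * in_dim - 2.0f * matmul(i, j) computes, since in_dim counts packed uint32 words here. A standalone NumPy check:
import numpy as np
a = np.random.choice([-1, 1], size=64).astype(np.float32)
b = np.random.choice([-1, 1], size=64).astype(np.float32)
pa = np.packbits((a >= 0).astype(np.uint8))  # pack sign bits, as binarize_pack does
pb = np.packbits((b >= 0).astype(np.uint8))
differing = int(np.unpackbits(pa ^ pb).sum())
assert a.dot(b) == a.size - 2 * differing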
/*!
* Copyright (c) 2017 by Contributors
* \brief Dense op constructions
* \file nn/dense.h
*/
#ifndef TOPI_NN_DENSE_H_
#define TOPI_NN_DENSE_H_
#include <string>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Creates an operation that calculates data * weight^T + bias
*
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense(const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor* bias) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias != nullptr) {
CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
}
auto batch = data->shape[0];
auto in_dim = data->shape[1];
auto out_dim = weight->shape[0];
auto k = tvm::reduce_axis(Range(0, in_dim), "k");
auto matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return tvm::sum(data(i, k) * weight(j, k), { k });
}, "tensor", "dense");
if (bias != nullptr) {
auto bias_val = *bias;
matmul = tvm::compute(
{ batch, out_dim },
[&](Var i, Var j) {
return matmul(i, j) + bias_val(j);
}, "tensor", kBroadcast);
}
return matmul;
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_DENSE_H_
/*!
* Copyright (c) 2017 by Contributors
* \brief Dilate op constructions
* \file nn/dilate.h
*/
#ifndef TOPI_NN_DILATE_H_
#define TOPI_NN_DILATE_H_
#include <string>
#include "tvm/tvm.h"
#include "tvm/ir_pass.h"
#include "topi/tags.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Create a new expression that is the logical AND of all
* the conditions in the arguments.
*
* \param args The arguments to find the logical conjunction of
*
* \return The logical conjunction expression
*/
Expr all(Array<Expr> args) {
CHECK_GT(args.size(), 0) << "all requires at least one argument";
Expr ret = args[0];
for (size_t i = 1; i < args.size(); ++i) {
ret = ret && args[i];
}
return ret;
}
/*!
* \brief Dilate data with zeros
*
* \param x The input tensor, this can have any number of
* dimensions and any layout.
* \param strides Dilation stride for each dimension. Stride 1
* means no dilation.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The output tensor.
*/
inline Tensor dilate(const Tensor& x,
Array<Expr> strides,
std::string name = "tensor",
std::string tag = kInjective) {
auto n = x->shape.size();
CHECK_EQ(n, strides.size())
<< "strides size (" << strides.size()
<< ") must match dimension of x (" << n << ")";
Array<Expr> out_shape;
for (size_t i = 0; i < n; ++i) {
out_shape.push_back(tvm::ir::Simplify(
(x->shape[i] - 1) * strides[i] + 1));
}
return tvm::compute(
out_shape,
[&](const Array<Var>& indices) {
Array<Expr> not_zero;
Array<Expr> index_tuple;
for (size_t i = 0; i < n; ++i) {
if (IsConstInt(strides[i]) && GetConstInt(strides[i]) == 1) {
index_tuple.push_back(indices[i]);
} else {
index_tuple.push_back(indices[i] / strides[i]);
not_zero.push_back((indices[i] % strides[i]) == 0);
}
}
if (not_zero.size() > 0) {
auto all_not_zero = all(not_zero);
return tvm::select(all_not_zero, x(index_tuple), make_const(x->dtype, 0));
}
return x(index_tuple);
}, name, tag);
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_DILATE_H_
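dilate above places the input elements stride apart and fills the gaps with zeros, so each output dimension has size (d - 1) * stride + 1. A NumPy sketch of the same behaviour (similar in spirit to the topi.testing.dilate_python reference used by the dilate test later in this diff):
import numpy as np
def dilate_ref(x, strides):
    out_shape = tuple((d - 1) * s + 1 for d, s in zip(x.shape, strides))
    out = np.zeros(out_shape, dtype=x.dtype)
    out[tuple(slice(None, None, s) for s in strides)] = x  # original values land on stride multiples
    return out
assert dilate_ref(np.array([1, 2, 3]), (2,)).tolist() == [1, 0, 2, 0, 3]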
/*!
* Copyright (c) 2017 by Contributors
* \brief Softmax op constructions
* \file nn/flatten.h
*/
#ifndef TOPI_NN_FLATTEN_H_
#define TOPI_NN_FLATTEN_H_
#include <string>
#include <vector>
#include "topi/tags.h"
#include "topi/detail/constant_utils.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions.
* This requires the input tensor to have constant sized dimensions.
*
* \param x The input tensor.
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A 2-D tensor.
*/
inline Tensor flatten(const Tensor& x,
std::string name = "tensor",
std::string tag = kInjective) {
auto ishape = x->shape;
int dim = 1;
for (size_t i = 1; i < ishape.size(); ++i) {
dim = dim * static_cast<int>(GetConstInt(ishape[i]));
}
Array<Expr> oshape({ ishape[0], dim });
std::vector<Expr> extra_shape;
for (size_t i = 1; i < ishape.size(); ++i) {
extra_shape.push_back(ishape[i]);
}
std::reverse(extra_shape.begin(), extra_shape.end());
return tvm::compute(
oshape, [&](Var i, Var j) {
Expr idx = j;
std::vector<Expr> index;
for (auto s : extra_shape) {
index.push_back(idx % s);
idx = idx / s;
}
index.push_back(i);
std::reverse(index.begin(), index.end());
return x(index);
});
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_FLATTEN_H_
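The index arithmetic in flatten above (repeated idx % s and idx / s over the trailing dimensions) is just a row-major reshape that keeps the first dimension; in NumPy terms:
import numpy as np
def flatten_ref(x):
    # collapse everything after the batch dimension, row-major order
    return x.reshape(x.shape[0], -1)
assert flatten_ref(np.zeros((2, 3, 4))).shape == (2, 12)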
/*!
* Copyright (c) 2017 by Contributors
* \brief Mapping op constructions
* \file nn/mapping.h
*/
#ifndef TOPI_NN_MAPPING_H_
#define TOPI_NN_MAPPING_H_
#include <string>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Scale and shift with NCHW order
*
* \param x The input tensor.
* \param scale Scale tensor, 1-D of size channel
* \param shift Shift tensor, 1-D of size channel
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the scale shift operation
*/
inline Tensor scale_shift_nchw(const Tensor& x,
const Tensor& scale,
const Tensor& shift,
std::string name = "ScaleShift",
std::string tag = kBroadcast) {
return tvm::compute(
x->shape,
[&](Var b, Var c, Var h, Var w) {
return x(b, c, h, w) * scale(c) + shift(c);
}, name, tag);
}
/*!
* \brief Scale and shift with NHWC order
*
* \param x The input tensor.
* \param scale Scale tensor, 1-D of size channel
* \param shift Shift tensor, 1-D of size channel
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the scale shift operation
*/
inline Tensor scale_shift_nhwc(const Tensor& x,
const Tensor& scale,
const Tensor& shift,
std::string name = "ScaleShift",
std::string tag = kBroadcast) {
return tvm::compute(
x->shape,
[&](Var b, Var h, Var w, Var c) {
return x(b, h, w, c) * scale(c) + shift(c);
}, name, tag);
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_MAPPING_H_
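Both scale_shift variants above apply a per-channel affine transform; only the position of the channel axis differs. A NumPy sketch for the NCHW case:
import numpy as np
def scale_shift_nchw_ref(x, scale, shift):
    # scale and shift are 1-D of length channel, broadcast over N, H, W
    return x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)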
/*!
* Copyright (c) 2017 by Contributors
* \brief Pooling op constructions
* \file nn/pooling.h
*/
#ifndef TOPI_NN_POOLING_H_
#define TOPI_NN_POOLING_H_
#include <string>
#include "tvm/tvm.h"
#include "tvm/ir_pass.h"
#include "topi/tags.h"
#include "topi/detail/pad_utils.h"
#include "topi/nn.h"
namespace topi {
namespace nn {
using namespace tvm;
/*! \brief Pooling type */
enum PoolType : int {
kAvgPool,
kMaxPool,
};
/*!
* \brief Perform pooling on data in NCHW order
*
* \param x The input tensor in NCHW order
* \param kernel_size Vector of two ints: {kernel_height, kernel_width}
* \param stride_size Vector of two ints: {stride_height, stride_width}
* \param padding_size Vector of two ints: {padding_height, padding_width}
* \param pool_type The type of pooling operator
* \param ceil_mode Whether to use ceil when calculating the output size
*
* \return The output tensor in NCHW order
*/
inline Tensor pool(const Tensor& x,
const Array<Expr>& kernel_size,
const Array<Expr>& stride_size,
const Array<Expr>& padding_size,
PoolType pool_type,
bool ceil_mode) {
CHECK_EQ(x->shape.size(), 4) << "Pooling input must be 4-D";
CHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
CHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
CHECK_EQ(padding_size.size(), 2) << "Pooling padding_size must have 2 elements";
auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];
auto padding_height = padding_size[0];
auto padding_width = padding_size[1];
auto batch = x->shape[0];
auto channel = x->shape[1];
auto height = x->shape[2];
auto width = x->shape[3];
auto pad_tuple = detail::GetPadTuple(padding_height, padding_width);
auto pad_top = pad_tuple[0];
auto pad_left = pad_tuple[1];
auto pad_down = pad_tuple[2];
auto pad_right = pad_tuple[3];
if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
// dividing by stride.
pad_down += stride_height - 1;
pad_right += stride_width - 1;
}
Array<Expr> pad_before{ 0, 0, pad_top, pad_left };
Array<Expr> pad_after{ 0, 0, pad_down, pad_right };
auto out_height = tvm::ir::Simplify(
(height - kernel_height + pad_top + pad_down) / stride_height + 1);
auto out_width = tvm::ir::Simplify(
(width - kernel_width + pad_left + pad_right) / stride_width + 1);
auto dheight = tvm::reduce_axis(Range(0, kernel_height));
auto dwidth = tvm::reduce_axis(Range(0, kernel_width));
if (pool_type == kMaxPool) {
auto temp = pad(x, pad_before, pad_after, x->dtype.min(), "pad_temp");
return tvm::compute(
{ batch, channel, out_height, out_width },
[&](Var n, Var c, Var h, Var w) {
return tvm::max(temp(n, c, h * stride_height + dheight, w * stride_width + dwidth),
{ dheight, dwidth });
}, "tensor", "pool_max");
} else if (pool_type == kAvgPool) {
auto temp = pad(x, pad_before, pad_after, 0, "pad_temp");
auto tsum = tvm::compute(
{ batch, channel, out_height, out_width },
[&](Var n, Var c, Var h, Var w) {
return tvm::sum(temp(n, c, h * stride_height + dheight, w * stride_width + dwidth),
{ dheight, dwidth });
}, "tensor", "pool_avg");
return tvm::compute(
{ batch, channel, out_height, out_width },
[&](Var n, Var c, Var h, Var w) {
return tsum(n, c, h, w) / (kernel_height * kernel_width);
}, "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
}
/*!
* \brief Perform global pooling on data in NCHW order
*
* \param x The input tensor in NCHW order
* \param pool_type The type of pooling operator
*
* \return The output tensor with shape [batch, channel, 1, 1]
*/
inline Tensor global_pool(const Tensor& x,
PoolType pool_type) {
CHECK_EQ(x->shape.size(), 4) << "Pooling input must be 4-D";
auto batch = x->shape[0];
auto channel = x->shape[1];
auto height = x->shape[2];
auto width = x->shape[3];
auto dheight = tvm::reduce_axis(Range(0, height));
auto dwidth = tvm::reduce_axis(Range(0, width));
if (pool_type == kMaxPool) {
return tvm::compute(
{ batch, channel, 1, 1 },
[&](Var n, Var c, Var h, Var w) {
return tvm::max(x(n, c, dheight, dwidth), { dheight, dwidth }); // NOLINT(*)
}, "tensor", "global_pool_max");
} else if (pool_type == kAvgPool) {
auto tsum = tvm::compute(
{ batch, channel, 1, 1 },
[&](Var n, Var c, Var h, Var w) {
return tvm::sum(x(n, c, dheight, dwidth), { dheight, dwidth });
}, "tensor", "global_pool_sum");
return tvm::compute(
{ batch, channel, 1, 1 },
[&](Var n, Var c, Var h, Var w) {
return tsum(n, c, h, w) / tvm::cast(x->dtype, height * width);
}, "tensor", kElementWise);
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return x;
}
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_POOLING_H_
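A note on the ceil_mode handling above: adding stride - 1 to the trailing padding makes the subsequent integer (floor) division behave like a ceiling, which is the output-size formula the pooling tests later in this diff assert against. A quick standalone check:
import math
def pool_out_size(size, kernel, stride, pad, ceil_mode):
    if ceil_mode:
        # the extra (stride - 1) turns floor division into a ceiling
        return (size - kernel + 2 * pad + stride - 1) // stride + 1
    return (size - kernel + 2 * pad) // stride + 1
assert pool_out_size(31, 3, 3, 2, True) == int(math.ceil((31 - 3 + 2 * 2) / 3.0) + 1)
assert pool_out_size(31, 3, 3, 2, False) == int(math.floor((31 - 3 + 2 * 2) / 3.0) + 1)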
/*!
* Copyright (c) 2017 by Contributors
* \brief Softmax op constructions
* \file nn/softmax.h
*/
#ifndef TOPI_NN_SOFTMAX_H_
#define TOPI_NN_SOFTMAX_H_
#include <algorithm>
#include <string>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Softmax activation
*
* \param x The input tensor. 2-D where softmax is performed along the second dimension
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the softmax operation
*/
inline Tensor softmax(const Tensor& x,
std::string name = "tensor",
std::string tag = "softmax_output") {
CHECK_EQ(x->shape.size(), 2) << "Softmax requires 2-D input";
Expr m = x->shape[0];
Expr n = x->shape[1];
auto k = tvm::reduce_axis(Range(0, n), "k");
auto max_elem = tvm::compute(
{ m }, [&](Var i) {
return tvm::max(x(i, k), Array<IterVar>{ k }); });
k = tvm::reduce_axis(Range(0, n), "k");
auto expsum = tvm::compute(
{ m }, [&](Var i) {
return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); });
return tvm::compute(
x->shape, [&](Var i, Var j) {
return tvm::exp(x(i, j) - max_elem(i)) / expsum(i);
});
}
/*!
* \brief Log softmax activation
*
* \param x The input tensor. 2-D where log softmax is performed along the second dimension
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the log softmax operation
*/
inline Tensor log_softmax(const Tensor& x,
std::string name = "tensor",
std::string tag = "log_softmax_output") {
CHECK_EQ(x->shape.size(), 2) << "Log softmax requires 2-D input";
Expr m = x->shape[0];
Expr n = x->shape[1];
auto k = tvm::reduce_axis(Range(0, n), "k");
auto max_elem = tvm::compute(
{ m }, [&](Var i) {
return tvm::max(x(i, k), Array<IterVar>{ k }); });
k = tvm::reduce_axis(Range(0, n), "k");
auto expsum = tvm::compute(
{ m }, [&](Var i) {
return tvm::sum(tvm::exp(x(i, k) - max_elem(i)), { k }); });
return tvm::compute(
x->shape, [&](Var i, Var j) {
return x(i, j) - max_elem(i) - tvm::log(expsum(i));
});
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_SOFTMAX_H_
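Both softmax and log_softmax above subtract the per-row maximum before exponentiating, which keeps exp() from overflowing without changing the result. A NumPy sketch of the same 2-D softmax:
import numpy as np
def softmax_ref(x):
    # subtract the row max first for numerical stability, as max_elem does above
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)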
/*!
* Copyright (c) 2017 by Contributors
* \file rocm/dense.h
* \brief rocm schedule for dense operation
*/
#ifndef TOPI_ROCM_DENSE_H_
#define TOPI_ROCM_DENSE_H_
#include "tvm/tvm.h"
#include "tvm/build_module.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/nn/dense.h"
#include "topi/contrib/rocblas.h"
#include "topi/generic/extern.h"
#include "topi/cuda/dense.h"
namespace topi {
using namespace tvm;
namespace rocm {
/*!
* \brief Implementation of dense for rocm backend
*
* \param target The target device
* \param data Tensor with shape [batch, in_dim]
* \param weight Tensor with shape [out_dim, in_dim]
* \param bias Tensor with shape [out_dim] (optional)
*
* \return Tensor with shape [batch, out_dim]
*/
inline tvm::Tensor dense_rocm(const Target& target,
const tvm::Tensor& data,
const tvm::Tensor& weight,
tvm::Tensor* bias) {
CHECK_EQ(data->shape.size(), 2) << "dense requires 2-D data";
CHECK_EQ(weight->shape.size(), 2) << "dense requires 2-D weight";
if (bias != nullptr) {
CHECK_EQ((*bias)->shape.size(), 1) << "dense requires 1-D bias";
}
auto batch = data->shape[0];
auto in_dim = data->shape[1];
auto out_dim = weight->shape[0];
if (target.libs.count("rocblas") > 0) {
auto mm = topi::contrib::rocblas_matmul(data, weight, false, true);
if (bias != nullptr) {
auto bias_val = *bias;
mm = tvm::compute({ batch, out_dim },
[&](Var i, Var j) {
return mm(i, j) + bias_val(j);
}, "tensor", kBroadcast);
}
return mm;
} else {
return topi::nn::dense(data, weight, bias);
}
}
/*!
* \brief Create a rocm schedule for dense
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_dense(const Target &target, const Array<Tensor>& outs) {
if (target.target_name == "rocm" &&
target.libs.count("rocblas") > 0) {
return topi::generic::schedule_extern(target, outs);
}
return topi::cuda::schedule_dense(target, outs);
}
} // namespace rocm
} // namespace topi
#endif // TOPI_ROCM_DENSE_H_
......@@ -6,9 +6,14 @@
#ifndef TOPI_TAGS_H_
#define TOPI_TAGS_H_
#include <string>
namespace topi {
constexpr auto kElementWise = "elemwise";
constexpr auto kInjective = "injective";
constexpr auto kCommReduce = "comm_reduce";
constexpr auto kCommReduceIdx = "comm_reduce_idx";
constexpr auto kBroadcast = "broadcast";
constexpr auto kMatMult = "matmult";
constexpr auto kConv2dNCHW = "conv2d_nchw";
......@@ -19,6 +24,19 @@ constexpr auto kDepthwiseConv2dBackInputNHWC = "depthwise_conv2d_back_input_nhwc
constexpr auto kDepthwiseConv2dBackWeightNHWC = "depthwise_conv2d_back_weight_nhwc";
constexpr auto kGroupConv2d = "group_conv2d";
inline bool is_broadcast(std::string tag) {
return
tag.rfind(kElementWise, 0) == 0 ||
tag.rfind(kBroadcast, 0) == 0;
}
inline bool is_injective(std::string tag) {
return
tag.rfind(kElementWise, 0) == 0 ||
tag.rfind(kBroadcast, 0) == 0 ||
tag.rfind(kInjective, 0) == 0;
}
} // namespace topi
#endif // TOPI_TAGS_H_
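is_broadcast and is_injective above classify an operation by tag prefix: rfind(prefix, 0) == 0 is simply a prefix test. The equivalent check in Python terms, using the constants defined in this header:
def is_injective_ref(tag):
    # rfind(prefix, 0) == 0  is equivalent to  tag.startswith(prefix)
    return any(tag.startswith(p) for p in ("elemwise", "broadcast", "injective"))
assert is_injective_ref("injective")
assert not is_injective_ref("comm_reduce")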
/*!
* Copyright (c) 2017 by Contributors
* \file x86/bnn.h
* \brief x86 schedule for binary operations
*/
#ifndef TOPI_X86_BNN_H_
#define TOPI_X86_BNN_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace x86 {
/*!
* \brief Create a generic schedule for binarize_pack
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_binarize_pack(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto _schedule = [&](const Tensor& out) {
s[out].parallel(out->op.as<ComputeOpNode>()->axis[0]);
};
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
if (op->tag == "binarize_pack") {
_schedule(op.output(0));
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
return s;
}
/*!
* \brief Create a generic schedule for binary_dense
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_binary_dense(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto _schedule = [&](const Tensor& A, const Tensor& B, const Tensor& C) {
IterVar co, ci;
s[C].split(s[C]->op.as<ComputeOpNode>()->reduce_axis[0], 8, &co, &ci);
s[C].parallel(s[C]->op.as<ComputeOpNode>()->axis[0]);
Tensor out;
if (contains(s->outputs, C->op)) {
out = C;
} else {
out = outs[0]->op.output(0);
}
IterVar xo, xi;
s[out].split(out->op.as<ComputeOpNode>()->axis[1], 8, &xo, &xi);
s[out].vectorize(xi);
};
std::function<void(Operation)> traverse;
traverse = [&](const Operation& op) {
// Inline all one-to-one-mapping operators except the last stage (output)
if (is_broadcast(op->tag)) {
if (!contains(s->outputs, op)) {
s[op].compute_inline();
}
for (auto tensor : op->InputTensors()) {
if (tensor->op->InputTensors().size() > 0) {
traverse(tensor->op);
}
}
} else if (op->tag == "binary_dense") {
auto output = op.output(0);
auto data = op->InputTensors()[0];
auto weight = op->InputTensors()[1];
_schedule(data, weight, output);
} else {
LOG(ERROR) << "Unsupported operator " << op->tag;
}
};
traverse(outs[0]->op);
return s;
}
} // namespace x86
} // namespace topi
#endif // TOPI_X86_BNN_H_
/*!
* Copyright (c) 2017 by Contributors
* \file x86/default.h
* \brief default x86 schedule
*/
#ifndef TOPI_X86_DEFAULT_H_
#define TOPI_X86_DEFAULT_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace x86 {
/*!
* \brief Create a default x86 schedule for the given ops.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
* \param auto_inline Whether to apply the auto inline step.
*
* \return A schedule for the given ops.
*/
Schedule default_schedule(const Target &target, const Array<Tensor>& outs, bool auto_inline) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
auto x = outs[0];
auto axis = s[x]->op.as<ComputeOpNode>()->axis;
if (auto_inline) {
tvm::schedule::AutoInlineInjective(s);
if (axis.size() > 0) {
Fuse(s[x], axis);
}
return s;
}
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
auto fused = Fuse(s[x], { n, c }); // fuse the two outermost axes (n and h for NHWC layout)
s[x].parallel(fused);
} else {
s[x].parallel(axis[0]);
}
return s;
}
} // namespace x86
} // namespace topi
#endif // TOPI_X86_DEFAULT_H_
/*!
* Copyright (c) 2017 by Contributors
* \file x86/injective.h
* \brief x86 schedule for injective ops
*/
#ifndef TOPI_X86_INJECTIVE_H_
#define TOPI_X86_INJECTIVE_H_
#include "topi/tags.h"
#include "topi/detail/fuse.h"
#include "tvm/tvm.h"
#include "tvm/build_module.h"
namespace topi {
using namespace tvm;
namespace x86 {
/*!
* \brief Create an x86 schedule for the given injective ops.
*
* \param target The target to generate a schedule for.
* \param outs The output tensors.
*
* \return A schedule for the given ops.
*/
Schedule schedule_injective(const Target &target, const Array<Tensor>& outs) {
Array<Operation> out_ops;
for (auto t : outs) {
out_ops.push_back(t->op);
}
auto s = create_schedule(out_ops);
tvm::schedule::AutoInlineInjective(s);
auto x = outs[0];
auto axis = s[x]->op.as<ComputeOpNode>()->axis;
if (axis.size() == 4) {
auto n = axis[0];
auto c = axis[1];
auto fused = Fuse(s[x], { n, c }); // fuse the two outermost axes (n and h for NHWC layout)
s[x].parallel(fused);
} else {
s[x].parallel(axis[0]);
}
return s;
}
} // namespace x86
} // namespace topi
#endif // TOPI_X86_INJECTIVE_H_
......@@ -2,6 +2,7 @@
"""Setup TOPI package."""
from __future__ import absolute_import
import sys
import os
from setuptools import find_packages
from setuptools.dist import Distribution
......@@ -13,7 +14,40 @@ else:
from setuptools import setup
from setuptools.extension import Extension
__version__ = "0.1.0"
def get_lib_names():
if sys.platform.startswith('win32'):
return ['libtvm_topi.dll', 'tvm_topi.dll']
if sys.platform.startswith('darwin'):
return ['libtvm_topi.dylib', 'tvm_topi.dylib']
return ['libtvm_topi.so', 'tvm_topi.so']
def get_lib_path():
"""Get library path, name and version"""
# We cannot import `libinfo.py` in setup.py directly since __init__.py
# would be invoked, which introduces dependencies
CURRENT_DIR = os.path.dirname(__file__)
libinfo_py = os.path.join(CURRENT_DIR, '../../python/tvm/_ffi/libinfo.py')
libinfo = {'__file__': libinfo_py}
exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
lib_path = libinfo['find_lib_path'](get_lib_names())
version = libinfo['__version__']
libs = [lib_path[0]]
if libs[0].find("runtime") == -1:
for name in lib_path[1:]:
if name.find("runtime") != -1:
libs.append(name)
break
return libs, version
LIB_LIST, __version__ = get_lib_path()
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
for i, path in enumerate(LIB_LIST):
LIB_LIST[i] = os.path.relpath(path, curr_path)
setup_kwargs = {
"include_package_data": True,
"data_files": [('topi', LIB_LIST)]
}
setup(name='topi',
version=__version__,
......@@ -23,4 +57,5 @@ setup(name='topi',
"decorator",
],
packages=find_packages(),
url='https://github.com/dmlc/tvm')
url='https://github.com/dmlc/tvm',
**setup_kwargs)
......@@ -9,6 +9,8 @@ specific workload.
"""
from __future__ import absolute_import as _abs
from tvm._ffi.libinfo import __version__
from .math import *
from .reduction import *
from .transform import *
......@@ -21,3 +23,4 @@ from . import mali
from . import testing
from . import util
from . import rocm
from . import cpp
"""FFI for C++ TOPI ops and schedules"""
import sys
import os
import ctypes
from imp import new_module as _new_module
from tvm._ffi.function import _init_api_prefix
from tvm._ffi import libinfo
import tvm as _tvm
def _get_lib_names():
if sys.platform.startswith('win32'):
return ['libtvm_topi.dll', 'tvm_topi.dll']
if sys.platform.startswith('darwin'):
return ['libtvm_topi.dylib', 'tvm_topi.dylib']
return ['libtvm_topi.so', 'tvm_topi.so']
def _load_lib():
"""Load libary by searching possible path."""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
lib_search = curr_path
lib_path = libinfo.find_lib_path(_get_lib_names(), lib_search, optional=True)
if lib_path is None:
return None, None
lib = ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)
return lib, os.path.basename(lib_path[0])
_LIB, _LIB_NAME = _load_lib()
_init_api_prefix("topi.cpp", "topi")
def _create_module(name):
fullname = __name__ + "." + name
mod = _new_module(fullname)
sys.modules[fullname] = mod
return mod
# pylint: disable-msg=C0103
nn = _create_module("nn")
_init_api_prefix("topi.cpp.nn", "topi.nn")
generic = _create_module("generic")
_init_api_prefix("topi.cpp.generic", "topi.generic")
cuda = _create_module("cuda")
_init_api_prefix("topi.cpp.cuda", "topi.cuda")
rocm = _create_module("rocm")
_init_api_prefix("topi.cpp.rocm", "topi.rocm")
x86 = _create_module("x86")
_init_api_prefix("topi.cpp.x86", "topi.x86")
class IntVector(object):
"""Handle to std::vector<int> instance """
_tvm_tcode = 27
def __init__(self, handle):
self.handle = handle
def __del__(self):
_tvm.nd.free_extension_handle(self.handle, 27)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
_tvm.register_extension(IntVector, IntVector)
class Target(object):
"""Handle to C++ Target instance """
_tvm_tcode = 28
def __init__(self, handle):
self.handle = handle
def __del__(self):
_tvm.nd.free_extension_handle(self.handle, 28)
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
_tvm.register_extension(Target, Target)
import tvm
import topi
from topi import util
def test_util():
x = tvm.const(100)
assert util.get_const_int(x) == 100
assert util.get_const_tuple((x, x)) == (100, 100)
def test_ewise():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
def test_apply(func, name):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
test_apply(topi.cpp.exp, "exp")
test_apply(topi.cpp.tanh, "tanh")
test_apply(topi.cpp.sigmoid, "sigmoid")
test_apply(topi.cpp.log, "log")
test_apply(topi.cpp.sqrt, "sqrt")
if __name__ == "__main__":
test_util()
test_ewise()
"""Test code for binary neural network operators."""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
def verify_binary_dense(batch, in_dim, out_dim):
A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B')
bnn_A = topi.cpp.nn.binarize_pack(A, 1)
bnn_B = topi.cpp.nn.binarize_pack(B, 1)
# binary dense
bnn_A1 = tvm.placeholder(bnn_A.shape, dtype=bnn_A.dtype)
bnn_B1 = tvm.placeholder(bnn_B.shape, dtype=bnn_B.dtype)
bnn_C = topi.cpp.nn.binary_dense(bnn_A1, bnn_B1)
# schedule
target = topi.cpp.TEST_create_target("llvm")
s1 = topi.cpp.x86.schedule_binarize_pack(target, [bnn_A])
s2 = topi.cpp.x86.schedule_binarize_pack(target, [bnn_B])
s3 = topi.cpp.x86.schedule_binary_dense(target, [bnn_C])
dtype = A.dtype
@memoize("topi.tests.test_topi_binary_dense")
def get_ref_data():
# generate random matrix of +1 or -1 value
a_np = (np.random.randint(2, size=(batch, in_dim)) * 2 - 1).astype(dtype)
b_np = (np.random.randint(2, size=(out_dim, in_dim)) * 2 - 1).astype(dtype)
c_np = np.dot(a_np, b_np.T)
return (a_np, b_np, c_np)
a_np, b_np, c_np = get_ref_data()
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
bnn_a = tvm.nd.array(np.zeros(get_const_tuple(bnn_A.shape), dtype=bnn_A.dtype), ctx)
bnn_b = tvm.nd.array(np.zeros(get_const_tuple(bnn_B.shape), dtype=bnn_B.dtype), ctx)
bnn_c = tvm.nd.array(np.zeros(get_const_tuple(bnn_C.shape), dtype=bnn_C.dtype), ctx)
f1 = tvm.build(s1, [A, bnn_A], 'llvm')
f2 = tvm.build(s2, [B, bnn_B], 'llvm')
f3 = tvm.build(s3, [bnn_A1, bnn_B1, bnn_C], 'llvm')
f1(a, bnn_a)
f2(b, bnn_b)
f3(bnn_a, bnn_b, bnn_c)
np.testing.assert_allclose(bnn_c.asnumpy(), c_np, rtol=1e-5)
def test_binary_dense():
verify_binary_dense(1, 4096, 1024)
verify_binary_dense(1, 1024, 1000)
if __name__ == "__main__":
test_binary_dense()
"""Test code for broadcasting operators."""
import os
import numpy as np
import tvm
import topi
def verify_broadcast_to_ele(in_shape, out_shape):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.cpp.broadcast_to(A, out_shape)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="broadcast_to")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = np.broadcast_to(data_npy, out_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
for _ in range(1):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
check_device("opencl")
check_device("cuda")
#check_device("metal")
#check_device("rocm")
def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
# Build the logic and compile the function
A = tvm.placeholder(shape=lhs_shape, name="A")
B = tvm.placeholder(shape=rhs_shape, name="B")
if typ == "add":
C = topi.cpp.broadcast_add(A, B)
elif typ == "sub":
C = topi.cpp.broadcast_sub(A, B)
elif typ == "div":
C = topi.cpp.broadcast_div(A, B)
elif typ == "mul":
C = topi.cpp.broadcast_mul(A, B)
elif typ == "maximum":
C = topi.cpp.broadcast_maximum(A, B)
elif typ == "minimum":
C = topi.cpp.broadcast_minimum(A, B)
elif typ == "pow":
C = topi.cpp.broadcast_pow(A, B)
else:
raise NotImplementedError
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.cuda.schedule_injective(target, [C])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B, C], device, name="broadcast_binary" + "_" + typ)
lhs_npy = np.random.uniform(size=lhs_shape).astype(A.dtype)
rhs_npy = np.random.uniform(size=rhs_shape).astype(A.dtype)
if typ == "add":
out_npy = lhs_npy + rhs_npy
elif typ == "sub":
out_npy = lhs_npy - rhs_npy
elif typ == "div":
rhs_npy = np.abs(rhs_npy) + 0.001
out_npy = lhs_npy / rhs_npy
elif typ == "mul":
out_npy = lhs_npy * rhs_npy
elif typ == "maximum":
out_npy = np.maximum(lhs_npy, rhs_npy)
elif typ == "minimum":
out_npy = np.minimum(lhs_npy, rhs_npy)
elif typ == "pow":
out_npy = lhs_npy ** rhs_npy
else:
raise NotImplementedError
lhs_nd = tvm.nd.array(lhs_npy, ctx)
rhs_nd = tvm.nd.array(rhs_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
for _ in range(1):
foo(lhs_nd, rhs_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy, rtol=1E-4, atol=1E-4)
check_device("opencl")
check_device("cuda")
#check_device("metal")
#check_device("rocm")
def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,))
verify_broadcast_to_ele((), (10,))
verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32))
def test_broadcast_binary():
verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add")
verify_broadcast_binary_ele((5, 2, 3), (), typ="add")
verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul")
verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div")
verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub")
verify_broadcast_binary_ele((32,), (64, 32), typ="maximum")
verify_broadcast_binary_ele((1, 2, 2, 1, 32), (64, 32), typ="minimum")
verify_broadcast_binary_ele((1, 32), (64, 32), typ="pow")
if __name__ == "__main__":
test_broadcast_to()
test_broadcast_binary()
"""Test code for clip operator"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from util import make_vector
def verify_clip(N, a_min, a_max, dtype):
A = tvm.placeholder((N, N), dtype=dtype, name='A')
B = topi.cpp.clip(A, a_min, a_max)
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_clip")
def get_ref_data():
a_np = np.random.uniform(a_min*2, a_max*2, size=(N, N)).astype(dtype)
b_np = np.clip(a_np, a_min, a_max)
return a_np, b_np
a_np, b_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.generic.default_schedule(target, [B], False)
ctx = tvm.cpu(0) if device == "llvm" else tvm.gpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device, name="clip")
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm']:
check_device(device)
def test_clip():
verify_clip(1024, -127, 127, 'int8')
verify_clip(1024, -127, 127, 'int16')
verify_clip(1024, -127, 127, 'float32')
if __name__ == "__main__":
test_clip()
"""Test code for dense operator"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
def verify_dense(batch, in_dim, out_dim, use_bias=True):
A = tvm.placeholder((batch, in_dim), name='A')
B = tvm.placeholder((out_dim, in_dim), name='B')
C = tvm.placeholder((out_dim,), name='C')
D = topi.cpp.nn.dense(A, B, C if use_bias else None)
D = topi.cpp.nn.relu(D)
dtype = A.dtype
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_dense")
def get_ref_data():
a_np = np.random.uniform(size=(batch, in_dim)).astype(dtype)
b_np = np.random.uniform(size=(out_dim, in_dim)).astype(dtype)
c_np = np.random.uniform(size=(out_dim,)).astype(dtype)
if use_bias:
d_np = np.maximum(np.dot(a_np, b_np.T) + c_np, 0.0)
else:
d_np = np.maximum(np.dot(a_np, b_np.T), 0.0)
return (a_np, b_np, c_np, d_np)
# get the test data
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_dense(target, [D])
elif device == "rocm":
s = topi.cpp.rocm.schedule_dense(target, [D])
else:
s = topi.cpp.cuda.schedule_dense(target, [D])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B, C, D], device, name="dense")
f(a, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
check_device(device)
def test_dense():
verify_dense(1, 1024, 1000, use_bias=True)
verify_dense(1, 1024, 1000, use_bias=False)
if __name__ == "__main__":
test_dense()
import tvm
import topi
import numpy as np
def test_dilate():
target = 'llvm'
ctx = tvm.cpu(0)
def _test_dilate(input_size, strides):
Input = tvm.placeholder((input_size))
Output = topi.cpp.nn.dilate(Input, strides)
tgt = topi.cpp.TEST_create_target(target)
schedule = topi.cpp.generic.default_schedule(tgt, [Output], True)
input_np = np.random.uniform(size=input_size).astype(Input.dtype)
output_np = topi.testing.dilate_python(input_np, strides)
input_tvm = tvm.nd.array(input_np, ctx=ctx)
output_size = topi.util.get_const_tuple(Output.shape)
output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx)
f = tvm.build(schedule, [Input, Output], target)
f(input_tvm, output_tvm)
np.testing.assert_allclose(output_tvm.asnumpy(), output_np, rtol=1e-5)
_test_dilate((32,), (2,))
_test_dilate((32,32), (2,2))
_test_dilate((1,3,32,32), (1,1,1,1))
_test_dilate((1,3,32,32), (2,2,2,2))
_test_dilate((1,32,32,3,3), (1,1,1,1,1))
_test_dilate((1,32,32,3,3), (2,2,2,2,2))
_test_dilate((1,32,32,32,3,3), (1,1,1,2,2,2))
_test_dilate((1,32,32,32,3,3), (2,2,2,1,1,1))
if __name__ == "__main__":
test_dilate()
"""Test code for pooling"""
import numpy as np
import tvm
import topi
import math
from topi.util import get_const_tuple
pool_code = {
"avg": 0,
"max": 1
}
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode):
iw = ih
kw = kh
sw = sh
ph, pw = padding
A = tvm.placeholder((n, ic, ih, iw), name='A')
B = topi.cpp.nn.pool(A, [kh, kw], [sh, sw], padding,
pool_code[pool_type], ceil_mode)
B = topi.cpp.nn.relu(B)
dtype = A.dtype
bshape = get_const_tuple(B.shape)
ashape = get_const_tuple(A.shape)
if ceil_mode:
assert bshape[2] == int(math.ceil(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.ceil(float(ashape[3] - kw + pw * 2) / sw) + 1)
else:
assert bshape[2] == int(math.floor(float(ashape[2] - kh + ph * 2) / sh) + 1)
assert bshape[3] == int(math.floor(float(ashape[3] - kw + pw * 2) / sw) + 1)
a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
pad_np = np.zeros(shape=(n, ic, ih+2*ph, iw+2*pw)).astype(dtype)
no_zero = (range(n), range(ic), (range(ph, ih+ph)), (range(pw, iw+pw)))
pad_np[np.ix_(*no_zero)] = a_np
_, oc, oh, ow = get_const_tuple(B.shape)
b_np = np.zeros(shape=(n, oc, oh, ow)).astype(dtype)
if pool_type == 'avg':
for i in range(oh):
for j in range(ow):
b_np[:,:,i,j] = np.mean(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
elif pool_type =='max':
for i in range(oh):
for j in range(ow):
b_np[:,:,i,j] = np.max(pad_np[:, :, i*sh:i*sh+kh, j*sw:j*sw+kw], axis=(2,3))
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_pool(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
check_device(device)
def test_pool():
verify_pool(1, 256, 32, 2, 2, [0, 0], 'avg', False)
verify_pool(1, 256, 31, 3, 3, [1, 2], 'avg', False)
verify_pool(1, 256, 32, 2, 2, [0, 0], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', False)
verify_pool(1, 256, 31, 3, 3, [2, 1], 'max', True)
def verify_global_pool(n, c, h, w, pool_type):
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.cpp.nn.global_pool(A, pool_code[pool_type])
B = topi.cpp.nn.relu(B)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
if pool_type == 'avg':
b_np = np.mean(a_np, axis=(2,3), keepdims=True)
elif pool_type =='max':
b_np = np.max(a_np, axis=(2,3), keepdims=True)
b_np = np.maximum(b_np, 0.0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_global_pool(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
check_device(device)
def test_global_pool():
verify_global_pool(1, 1024, 7, 7, 'avg')
verify_global_pool(4, 1024, 7, 7, 'avg')
verify_global_pool(1, 1024, 7, 7, 'max')
verify_global_pool(4, 1024, 7, 7, 'max')
if __name__ == "__main__":
test_pool()
test_global_pool()
"""Test code for reduce."""
import os
import numpy as np
import tvm
import topi
def _my_npy_argmax(arr, axis, keepdims):
if not keepdims:
return arr.argmax(axis=axis)
else:
if axis is not None:
out_shape = list(arr.shape)
out_shape[axis] = 1
else:
out_shape = [1 for _ in range(len(arr.shape))]
return arr.argmax(axis=axis).reshape(out_shape)
def _my_npy_argmin(arr, axis, keepdims):
if not keepdims:
return arr.argmin(axis=axis)
else:
out_shape = list(arr.shape)
out_shape[axis] = 1
return arr.argmin(axis=axis).reshape(out_shape)
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
# Build the logic and compile the function
dat_dtype = "float32"
A = tvm.placeholder(shape=in_shape, name="A", dtype=dat_dtype)
A1 = topi.cpp.sqrt(topi.cpp.exp(A))
out_dtype = "float32"
if type == "sum":
B = topi.cpp.sum(A1, axis, keepdims)
elif type == "max":
B = topi.cpp.max(A1, axis, keepdims)
elif type == "min":
B = topi.cpp.min(A1, axis, keepdims)
elif type == "argmax":
B = topi.cpp.argmax(A1, axis, keepdims)
out_dtype = "int32"
elif type == "argmin":
B = topi.cpp.argmin(A1, axis, keepdims)
out_dtype = "int32"
else:
raise NotImplementedError
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], True)
else:
s = topi.cpp.cuda.schedule_reduce(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="sum")
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
if type == "sum":
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "max":
out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
elif type == "min":
out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
elif type == "argmax":
out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
elif type == "argmin":
out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
data_tvm = tvm.nd.array(in_npy, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1):
foo(data_tvm, out_tvm)
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm"]:
check_device(device)
def test_reduce_map():
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=(1, 2, 3),
keepdims=True,
type="sum")
verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
axis=(1,),
keepdims=False,
type="max")
verify_reduce_map_ele(in_shape=(32, 128, 24),
axis=None,
keepdims=True,
type="sum")
verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
axis=(0, 2),
keepdims=False,
type="min")
verify_reduce_map_ele(in_shape=(32, 128),
axis=1,
keepdims=True,
type="argmax")
verify_reduce_map_ele(in_shape=(32, 24, 32, 24),
axis=2,
keepdims=False,
type="argmin")
verify_reduce_map_ele(in_shape=(31, 21, 15),
axis=None,
keepdims=True,
type="argmax")
verify_reduce_map_ele(in_shape=(31, 21, 15),
axis=None,
keepdims=False,
type="sum")
if __name__ == "__main__":
test_reduce_map()
"""Test code for relu activation"""
import os
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def verify_relu(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.cpp.nn.relu(A)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="relu")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
check_device(device)
def verify_leaky_relu(m, alpha):
A = tvm.placeholder((m,), name='A')
B = topi.cpp.nn.leaky_relu(A, alpha)
device = "llvm"
target = topi.cpp.TEST_create_target(device)
s = topi.cpp.generic.schedule_injective(target, [B])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = a_np * (a_np > 0) + a_np * (a_np < 0) * alpha
ctx = tvm.cpu(0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="leaky_relu")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
def test_relu():
verify_relu(10, 128)
def test_leaky_relu():
verify_leaky_relu(100, 0.1)
if __name__ == "__main__":
test_relu()
test_leaky_relu()
"""Test code for softmax"""
import os
import numpy as np
import tvm
import topi
import logging
from topi.util import get_const_tuple
def verify_softmax(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.cpp.nn.softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.softmax_python(a_np)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_softmax(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm']:
check_device(device)
def test_softmax():
verify_softmax(32, 10)
verify_softmax(3, 4)
def verify_log_softmax(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.cpp.nn.log_softmax(A)
# confirm lower works
s = tvm.create_schedule([B.op])
tvm.lower(s, [A, B], simple_mode=True)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
b_np = topi.testing.log_softmax_python(a_np)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.default_schedule(target, [B], False)
else:
s = topi.cpp.cuda.schedule_softmax(target, [B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
foo = tvm.build(s, [A, B], device, name="log_softmax")
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ["cuda", "opencl", "metal", "rocm"]:
check_device(device)
def test_log_softmax():
verify_log_softmax(32, 10)
verify_log_softmax(3, 4)
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
test_softmax()
test_log_softmax()
"""Test code for broadcasting operators."""
import numpy as np
import tvm
import topi
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.cpp.expand_dims(A, axis, num_newaxis)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="expand_dims")
data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
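        # Inserting size-1 axes leaves the data layout unchanged, so the reference result is just a reshape to out_shape.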
out_npy = data_npy.reshape(out_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_transpose(in_shape, axes):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.cpp.transpose(A, axes)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
        foo = tvm.build(s, [A, B], device, name="transpose")
data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
out_npy = data_npy.transpose(axes)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_reshape(src_shape, dst_shape):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.cpp.reshape(A, dst_shape)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="reshape")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.reshape(data_npy, newshape=dst_shape)
data_nd = tvm.nd.array(data_npy, ctx)
out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_squeeze(src_shape, axis):
A = tvm.placeholder(shape=src_shape, name="A")
B = topi.cpp.squeeze(A, axis)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, [B])
else:
s = topi.cpp.cuda.schedule_injective(target, [B])
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A, B], device, name="squeeze")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
out_npy = np.squeeze(data_npy, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx)
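        # np.squeeze over every axis yields a 0-d array; allocate a (1,)-element TVM array so there is storage for the scalar result.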
if out_npy.shape == ():
out_nd_shape = (1,)
else:
out_nd_shape = out_npy.shape
out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype)
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_concatenate(shapes, axis):
tensor_l = []
for i, shape in enumerate(shapes):
tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
out_tensor = topi.cpp.concatenate(tensor_l, axis)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, [out_tensor])
else:
s = topi.cpp.cuda.schedule_injective(target, [out_tensor])
ctx = tvm.context(device, 0)
foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
out_npy = np.concatenate(data_npys, axis=axis)
data_nds = [tvm.nd.array(data_npy, ctx) for data_npy in data_npys]
out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=out_tensor.dtype)
foo(*(data_nds + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def verify_split(src_shape, indices_or_sections, axis):
A = tvm.placeholder(shape=src_shape, name="A")
tensor_l = topi.cpp.split(A, indices_or_sections, axis)
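    # The C++ split op returns a TVM array container; convert it to a plain Python list for scheduling and building.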
tensor_l = list(tensor_l)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
target = topi.cpp.TEST_create_target(device)
if device == "llvm":
s = topi.cpp.generic.schedule_injective(target, tensor_l)
else:
s = topi.cpp.cuda.schedule_injective(target, tensor_l)
ctx = tvm.context(device, 0)
foo = tvm.build(s, [A] + tensor_l, device, name="split")
data_npy = np.random.normal(size=src_shape).astype(A.dtype)
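        # NumPy reference: an integer splits into equal sections, a list of indices splits at those positions along the axis.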
out_npys = np.split(data_npy, indices_or_sections, axis=axis)
data_nd = tvm.nd.array(data_npy, ctx)
out_nds = [tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=tensor_l[0].dtype) for out_npy in out_npys]
foo(*([data_nd] + out_nds))
for out_nd, out_npy in zip(out_nds, out_npys):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm"]:
check_device(device)
def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
def test_transpose():
    verify_transpose((3, 10, 2), (1, 0, 2))
    verify_transpose((3, 10, 5), (2, 0, 1))
    verify_transpose((3, 10), None)
def test_reshape():
verify_reshape((1, 2, 3, 4), (2, 3, 4))
verify_reshape((4, 2, 3, 4), (2, 4, 12))
verify_reshape((4, 2, 3, 4), (2, 48))
verify_reshape((16, ), (2, 2, 2, 2))
def test_squeeze():
verify_squeeze((1, 2, 3, 4), 0)
verify_squeeze((1, 2, 1, 4), None)
verify_squeeze((1, 1, 1, 4), (1, 2))
verify_squeeze((1, 1, 1, 1), None)
def test_concatenate():
verify_concatenate([(2,), (2,), (2,)], 0)
verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
verify_concatenate([(5, 6, 7, 3),
(16, 6, 7, 3),
(12, 6, 7, 3),
(8, 6, 7, 3),
(2, 6, 7, 3)], 0)
def test_split():
verify_split((2, 12, 3), 3, 1)
verify_split((2, 12, 3), [2, 4], 1)
verify_split((10, 12, 24), [5, 7, 9], -1)
if __name__ == "__main__":
test_concatenate()
    test_transpose()
test_expand_dims()
test_reshape()
test_squeeze()
test_split()