Commit 8f240ee7 by Tianqi Chen, committed by GitHub

[CODEGEN/LLVM] Initial support for codegen LLVM. (#49)

* [LLVM] Initial support for codegen LLVM.

* Fix the naming issue of codegen
parent 3555769e
-Subproject commit e68ae61cd541ac29efc9fafe2ad061479bcaa9c9
+Subproject commit 1a11a6c2522b1d11a5ccdb9b4fe3976cbe7f9f27
ifndef config
ifneq ("$(wildcard ./config.mk)","")
-config = config.mk
+config ?= config.mk
else
-config = make/config.mk
+config ?= make/config.mk
endif
endif
@@ -19,24 +19,16 @@ SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
-ifneq ($(USE_CUDA_PATH), NONE)
-NVCC=$(USE_CUDA_PATH)/bin/nvcc
-endif
export LDFLAGS = -pthread -lm
-export CFLAGS = -std=c++11 -Wall -O2\
- -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
+export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
+ -Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
-export FRAMEWORKS=
-ifneq ($(ADD_CFLAGS), NONE)
-CFLAGS += $(ADD_CFLAGS)
-endif
-ifneq ($(ADD_LDFLAGS), NONE)
-LDFLAGS += $(ADD_LDFLAGS)
-endif
+ifdef CUDA_PATH
+NVCC=$(CUDA_PATH)/bin/nvcc
+CFLAGS += -I$(CUDA_PATH)/include
+LDFLAGS += -L$(CUDA_PATH)/lib64
+endif
ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc
@@ -44,6 +36,7 @@ else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif
+FRAMEWORKS=
ifeq ($(USE_OPENCL), 1)
CFLAGS += -DTVM_OPENCL_RUNTIME=1
@@ -57,6 +50,23 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif
# llvm configuration
LLVM_CONFIG=llvm-config
ifeq ($(USE_LLVM), 1)
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION)
endif
ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
endif
ifdef ADD_LDFLAGS
LDFLAGS += $(ADD_LDFLAGS)
endif
include tests/cpp/unittest.mk
......
-Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6
+Subproject commit 8dd365636528175e785448cf8a9f4e494c8ee0e0
@@ -90,7 +90,7 @@ class BufferNode : public Node {
Type dtype);
static constexpr const char* _type_key = "Buffer";
-TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
+TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
};
inline const BufferNode* Buffer::operator->() const {
......
@@ -32,6 +32,13 @@ PackedFunc BuildStackVM(
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);
+/*!
+ * \brief Build an LLVM JIT function; this is still beta.
+ * \param func The LoweredFunc to be built.
+ * \return A packed function representing the func.
+ */
+PackedFunc BuildLLVM(LoweredFunc func);
/*!
 * \brief Build a CUDA function with NVRTC
 *
 * \param fsplits The LoweredFuncs to be built (after SplitHostDevice)
......
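For orientation, here is a minimal usage sketch (not part of this commit) of the new entry point from C++. It assumes `fapi` is a LoweredFunc produced by the existing lowering passes, and the helper name `RunPacked` is purely illustrative:

// Sketch only: JIT-compile a lowered function through the new LLVM path and invoke it.
#include <tvm/codegen.h>
#include <tvm/runtime/packed_func.h>

void RunPacked(tvm::LoweredFunc fapi,
               tvm::runtime::TVMArgs args,
               tvm::runtime::TVMRetValue* rv) {
  // BuildLLVM compiles the function with the in-process LLVM JIT
  // and wraps the result as a PackedFunc.
  tvm::runtime::PackedFunc f = tvm::codegen::BuildLLVM(fapi);
  f.CallPacked(args, rv);  // invoke with packed arguments
}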
@@ -36,7 +36,6 @@ using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}
@@ -182,7 +181,7 @@ class IterVarNode : public Node {
static IterVar make(Range dom, Var var, std::string thread_tag);
static constexpr const char* _type_key = "IterVar";
-TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
+TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
};
// inline implementations
......
@@ -200,6 +200,8 @@ using Halide::Internal::Realize;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
+// ir functions
+using Halide::Internal::is_const_power_of_two_integer;
}  // namespace ir
}  // namespace tvm
......
@@ -92,7 +92,7 @@ class LoweredFuncNode : public FunctionBaseNode {
}
static constexpr const char* _type_key = "LoweredFunc";
-TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
+TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
};
// Implementations of inline functions
......
@@ -39,7 +39,7 @@ class PlaceholderOpNode : public OperationNode {
Type dtype);
static constexpr const char* _type_key = "PlaceholderOp";
-TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
+TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
};
/*!
@@ -74,7 +74,7 @@ class ComputeOpNode : public OperationNode {
Expr body);
static constexpr const char* _type_key = "ComputeOp";
-TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
+TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
};
/*!
@@ -123,7 +123,7 @@ class ScanOpNode : public OperationNode {
Array<Tensor> state_placeholder);
static constexpr const char* _type_key = "ScanOp";
-TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
+TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
};
......
@@ -33,7 +33,7 @@ struct NodeTypeChecker {
// It can be turned off, but will make the checking non-strict.
// TODO(tqchen) possibly find alternative to turn off RTTI
using ContainerType = typename T::ContainerType;
-return (dynamic_cast<ContainerType*>(sptr) != nullptr);
+return sptr->derived_from<ContainerType>();
}
static inline void PrintName(std::ostringstream& os) {  // NOLINT(*)
using ContainerType = typename T::ContainerType;
......
@@ -153,6 +153,13 @@ typedef void* TVMRetValueHandle;
typedef TVMArray* TVMArrayHandle;
+/*!
+ * \brief Used for implementing C API functions.
+ * Set the last error message before returning.
+ * \param msg The error message to be set.
+ */
+TVM_DLL void TVMAPISetLastError(const char* msg);
/*!
 * \brief return str message of the last error
 * all functions in this file will return 0 when success
 * and -1 when an error occurred,
@@ -287,10 +294,10 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
 * \param num_args Number of arguments.
 * \param ret The return value handle.
 * \param resource_handle The additional resource handle from the front-end.
- *
+ * \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError.
 * \sa TVMCFuncSetReturn
 */
-typedef void (*TVMPackedCFunc)(
+typedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
int num_args,
......
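As a quick illustration of the changed convention (a sketch, not part of the commit): a packed C function now returns an int status, setting the error string through TVMAPISetLastError and returning -1 on failure. The trailing ret/resource_handle parameters follow the documentation above, and the function name is hypothetical:

// Sketch of a C callback under the new TVMPackedCFunc signature.
static int MyPackedCFunc(TVMValue* args, int* type_codes, int num_args,
                         TVMRetValueHandle ret, void* resource_handle) {
  if (num_args < 1) {
    // record a message retrievable via TVMGetLastError() and signal failure
    TVMAPISetLastError("MyPackedCFunc expects at least one argument");
    return -1;
  }
  // ... do the actual work, optionally setting a result with TVMCFuncSetReturn(ret, ...)
  return 0;  // success
}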
@@ -331,7 +331,7 @@ class StageNode : public Node {
}
static constexpr const char* _type_key = "Stage";
-TVM_DECLARE_NODE_TYPE_INFO(StageNode);
+TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
};
/*! \brief node container for schedule */
@@ -354,7 +354,7 @@ class ScheduleNode : public Node {
}
static constexpr const char* _type_key = "Schedule";
-TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
+TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};
/*! \brief node container for IterVar attr */
@@ -368,11 +368,14 @@ class IterVarAttrNode : public Node {
}
static constexpr const char* _type_key = "IterVarAttr";
-TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode);
+TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
};
/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
+ public:
+ static constexpr const char* _type_key = "IterVarRelation";
+ TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
};
/*!
@@ -402,7 +405,7 @@ class SplitNode : public IterVarRelationNode {
IterVar inner, Expr factor);
static constexpr const char* _type_key = "Split";
-TVM_DECLARE_NODE_TYPE_INFO(SplitNode);
+TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
};
/*!
@@ -427,7 +430,7 @@ class FuseNode : public IterVarRelationNode {
IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse";
-TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
+TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
};
/*!
@@ -450,7 +453,7 @@ class RebaseNode : public IterVarRelationNode {
static IterVarRelation make(IterVar parent, IterVar rebased);
static constexpr const char* _type_key = "Rebase";
-TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
+TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
};
......
@@ -153,7 +153,7 @@ class TensorNode : public Node {
int value_index);
static constexpr const char* _type_key = "Tensor";
-TVM_DECLARE_NODE_TYPE_INFO(TensorNode);
+TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node);
};
/*!
@@ -167,8 +167,6 @@ class OperationNode : public FunctionBaseNode {
const std::string& func_name() const final {
return name;
}
-/*! \return number of outputs of this op */
-virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return type of i-th output */
@@ -177,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
virtual Array<Expr> output_shape(size_t i) const = 0;
static constexpr const char* _type_key = "Operation";
+TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
};
// Implementations of inline functions
......
@@ -40,10 +40,11 @@ USE_CUDA = 1
# whether use OpenCL during compile
USE_OPENCL = 0
-# add the path to CUDA library to link and compile flag
-# if you have already add them to environment variable, leave it as NONE
-# USE_CUDA_PATH = /usr/local/cuda
-USE_CUDA_PATH = NONE
-# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
-USE_NVRTC = 0
+# whether build with LLVM support
+# This requires llvm-config to be in your PATH
+# Requires LLVM version >= 4.0
+USE_LLVM = 0
+# add the path to CUDA library to link and compile flag
+# if you have already added them to environment variables.
+# CUDA_PATH = /usr/local/cuda
@@ -56,6 +56,7 @@ def convert_to_tvm_func(pyfunc):
        check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
        _ = temp_args
        _ = rv
+        return 0
    handle = FunctionHandle()
    f = TVMPackedCFunc(cfun)
......
@@ -96,7 +96,7 @@ class TVMByteArray(ctypes.Structure):
TVMPackedCFunc = ctypes.CFUNCTYPE(
-    None,
+    ctypes.c_int,
    ctypes.POINTER(TVMValue),
    ctypes.POINTER(ctypes.c_int),
    ctypes.c_int,
......
@@ -6,3 +6,4 @@
- arithmetic Arithmetic expression and set simplification
- pass The optimization pass on the IR structure
- runtime Minimum runtime related codes.
- codegen The code generator
@@ -37,6 +37,11 @@ TVM_REGISTER_API(_codegen_BuildStackVM)
std::unordered_map<LoweredFunc, PackedFunc>());
});
+TVM_REGISTER_API(_codegen_BuildLLVM)
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+*ret = BuildLLVM(args[0]);
+});
TVM_REGISTER_API(_codegen_BuildNVRTC)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildNVRTC(args[0], args[1]);
......
@@ -20,7 +20,7 @@ enum SignType {
};
// internal node container of int set.
-class IntSetNode;
+struct IntSetNode;
/*!
 * \brief Integer set class, represent a set of integers in one dimension.
@@ -104,6 +104,8 @@ class IntSet : public NodeRef {
 * \brief Base class of all IntSet containers.
 */
struct IntSetNode : public Node {
+static constexpr const char* _type_key = "IntSet";
+TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};
using ExprIntSetMap = std::unordered_map<Expr, IntSet,
......
@@ -35,7 +35,7 @@ struct IntervalSet : public IntSetNode {
}
static constexpr const char* _type_key = "IntervalSet";
-TVM_DECLARE_NODE_TYPE_INFO(IntervalSet);
+TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode);
};
/*!
@@ -51,7 +51,7 @@ struct StrideSet : public IntSetNode {
Array<Expr> strides;
static constexpr const char* _type_key = "StrideSet";
-TVM_DECLARE_NODE_TYPE_INFO(StrideSet);
+TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};
}  // namespace arith
......
@@ -272,9 +272,6 @@ inline void PushBinary(StackVM::OpCode op_int64,
}
}
inline void PushCast(Type dst,
Type src,
CodeGenStackVM* p) {
@@ -496,7 +493,5 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) {
p->Push_(op);
});
}  // namespace codegen
}  // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_llvm.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h>
#include "./codegen_llvm.h"
#include "../../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
void CodeGenLLVM::Init(const std::string& module_name,
llvm::LLVMContext* ctx) {
InitializeLLVM();
static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
static_assert(alignof(TVMValue) == alignof(double), "invariant");
// clear maps
var_map_.clear();
str_map_.clear();
func_handle_map_.clear();
// initialize types.
if (ctx_ != ctx) {
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvm::Type::getInt8Ty(*ctx)->getPointerTo();
t_int_ = llvm::Type::getIntNTy(*ctx, sizeof(int) * 8);
t_char_ = llvm::Type::getInt8Ty(*ctx);
t_int8_ = llvm::Type::getInt8Ty(*ctx);
t_int16_ = llvm::Type::getInt16Ty(*ctx);
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
t_tvm_func_handle_ = t_void_p_;
t_tvm_array_ = llvm::StructType::create(
{t_void_p_,
t_tvm_index_->getPointerTo(),
t_tvm_index_->getPointerTo(),
t_tvm_index_,
t_tvm_type_,
t_tvm_context_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0);
md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
}
ctx_ = ctx;
// initialize modules
module_.reset(new llvm::Module(module_name, *ctx));
// initialize TVM runtime API
f_tvm_func_call_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_tvm_func_handle_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo(),
t_int_,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
f_tvm_func_get_global_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_char_->getPointerTo(),
t_tvm_func_handle_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMFuncGetGlobal", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
// initialize builder
builder_.reset(new IRBuilder(*ctx));
}
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
var_map_.clear();
CHECK(!module_->getFunction(f->name))
<< "Function " << f->name << "already exists in module";
std::vector<llvm::Type*> arg_type;
for (Var arg : f->args) {
Type t = arg.type();
if (t.is_handle() && f->handle_data_type.count(arg)) {
arg_type.push_back(
LLVMType(f->handle_data_type[arg].type())->getPointerTo());
} else {
arg_type.push_back(LLVMType(t));
}
}
llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_type, false);
// setup the function.
function_ = llvm::cast<llvm::Function>(module_->getOrInsertFunction(f->name, ftype));
function_->setCallingConv(llvm::CallingConv::C);
size_t idx = 0;
for (auto it = function_->arg_begin();
it != function_->arg_end(); ++it, ++idx) {
llvm::Argument* v = &(*it);
var_map_[f->args[idx].get()] = v;
}
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block);
this->Visit(f->body);
builder_->CreateRet(ConstInt32(0));
}
class FPassManager : public llvm::legacy::FunctionPassManager {
public:
explicit FPassManager(llvm::Module* m)
: llvm::legacy::FunctionPassManager(m) {}
// override add to allow messaging
void add(llvm::Pass* p) final {
llvm::legacy::FunctionPassManager::add(p);
}
};
class MPassManager : public llvm::legacy::PassManager {
public:
// override add to allow messaging
void add(llvm::Pass* p) final {
llvm::legacy::PassManager::add(p);
}
};
void CodeGenLLVM::Optimize() {
// place optimization pass
llvm::PassManagerBuilder builder;
builder.OptLevel = 3;
builder.Inliner = llvm::createFunctionInliningPass(builder.OptLevel, 0);
builder.LoopVectorize = true;
builder.SLPVectorize = true;
// pass manager
FPassManager fpass(module_.get());
MPassManager mpass;
builder.populateFunctionPassManager(fpass);
builder.populateModulePassManager(mpass);
fpass.doInitialization();
for (auto it = module_->begin(); it != module_->end(); ++it) {
fpass.run(*it);
}
fpass.doFinalization();
mpass.run(*module_);
}
std::unique_ptr<llvm::Module> CodeGenLLVM::Finish() {
this->Optimize();
var_map_.clear();
str_map_.clear();
func_handle_map_.clear();
return std::move(module_);
}
llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
llvm::Type* ret = nullptr;
if (t.is_uint() || t.is_int()) {
ret = llvm::Type::getIntNTy(*ctx_, t.bits());
} else if (t.is_float()) {
switch (t.bits()) {
case 16: ret = llvm::Type::getHalfTy(*ctx_); break;
case 32: ret = llvm::Type::getFloatTy(*ctx_); break;
case 64: ret = llvm::Type::getDoubleTy(*ctx_); break;
default: LOG(FATAL) << "cannot handle " << t;
}
} else {
CHECK(t.is_handle());
ret = t_void_p_;
}
if (t.lanes() != 1) {
ret = llvm::VectorType::get(ret, t.lanes());
}
return ret;
}
void CodeGenLLVM::Visit_(const Variable* op) {
value_ = GetVarValue(op);
}
void CodeGenLLVM::Visit_(const Cast* op) {
value_ = CreateCast(op->value.type(), op->type, MakeValue(op->value));
}
void CodeGenLLVM::Visit_(const IntImm* op) {
value_ = llvm::ConstantInt::getSigned(LLVMType(op->type), op->value);
}
void CodeGenLLVM::Visit_(const UIntImm* op) {
value_ = llvm::ConstantInt::get(LLVMType(op->type), op->value);
}
void CodeGenLLVM::Visit_(const FloatImm* op) {
value_ = llvm::ConstantFP::get(LLVMType(op->type), op->value);
}
void CodeGenLLVM::Visit_(const StringImm* op) {
value_ = GetConstString(op->value);
}
#define DEFINE_CODEGEN_BINARY_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value *b) { \
if (t.is_float()) { \
return builder_->CreateF ## OP (a, b); \
} else if (t.is_int() && t.bits() >= 32) { \
return builder_->CreateNSW ## OP (a, b); \
} else { \
return builder_->Create ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_BINARY_OP(Add);
DEFINE_CODEGEN_BINARY_OP(Sub);
DEFINE_CODEGEN_BINARY_OP(Mul);
void CodeGenLLVM::Visit_(const Add* op) {
value_ = CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Sub* op) {
value_ = CreateSub(op->type, MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Mul* op) {
value_ = CreateMul(op->type, MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Div* op) {
llvm::Value* a = MakeValue(op->a);
int shift;
if (op->type.is_float()) {
value_ = builder_->CreateFDiv(a, MakeValue(op->b));
} else if ((op->type.is_int() || op->type.is_uint()) &&
is_const_power_of_two_integer(op->b, &shift)) {
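// constant power-of-two divisor: lower the division to an arithmetic right shift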
value_ = builder_->CreateAShr(a, shift);
} else {
llvm::Value* b = MakeValue(op->b);
if (op->type.is_int()) {
value_ = builder_->CreateSDiv(a, b);
} else {
CHECK(op->type.is_uint());
value_ = builder_->CreateUDiv(a, b);
}
}
}
void CodeGenLLVM::Visit_(const Mod* op) {
CHECK(!op->type.is_float())
<< "Cannot do mod for float";
if (op->type.is_int()) {
value_ = builder_->CreateSRem(MakeValue(op->a), MakeValue(op->b));
} else {
CHECK(op->type.is_uint());
value_ = builder_->CreateURem(MakeValue(op->a), MakeValue(op->b));
}
}
void CodeGenLLVM::Visit_(const Min* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateLT(op->a.type(), a, b);
value_ = builder_->CreateSelect(cond, a, b);
}
void CodeGenLLVM::Visit_(const Max* op) {
llvm::Value* a = MakeValue(op->a);
llvm::Value* b = MakeValue(op->b);
llvm::Value* cond = CreateGT(op->a.type(), a, b);
value_ = builder_->CreateSelect(cond, a, b);
}
#define DEFINE_CODEGEN_CMP_OP(OP) \
llvm::Value* CodeGenLLVM::Create ## OP( \
Type t, llvm::Value* a, llvm::Value* b) { \
if (t.is_float()) { \
return builder_->CreateFCmpO ## OP (a, b); \
} else if (t.is_int()) { \
return builder_->CreateICmpS ## OP (a, b); \
} else { \
return builder_->CreateICmpU ## OP (a, b); \
} \
} \
DEFINE_CODEGEN_CMP_OP(LT);
DEFINE_CODEGEN_CMP_OP(LE);
DEFINE_CODEGEN_CMP_OP(GT);
DEFINE_CODEGEN_CMP_OP(GE);
void CodeGenLLVM::Visit_(const LT* op) {
value_ = CreateLT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const LE* op) {
value_ = CreateLE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const GT* op) {
value_ = CreateGT(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const GE* op) {
value_ = CreateGE(op->a.type(), MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const EQ* op) {
if (op->a.type().is_float()) {
value_ = builder_->CreateFCmpOEQ(MakeValue(op->a), MakeValue(op->b));
} else {
value_ = builder_->CreateICmpEQ(MakeValue(op->a), MakeValue(op->b));
}
}
void CodeGenLLVM::Visit_(const NE* op) {
if (op->a.type().is_float()) {
value_ = builder_->CreateFCmpONE(MakeValue(op->a), MakeValue(op->b));
} else {
value_ = builder_->CreateICmpNE(MakeValue(op->a), MakeValue(op->b));
}
}
void CodeGenLLVM::Visit_(const And* op) {
value_ = builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Or* op) {
value_ = builder_->CreateOr(MakeValue(op->a), MakeValue(op->b));
}
void CodeGenLLVM::Visit_(const Not* op) {
value_ = builder_->CreateNot(MakeValue(op->a));
}
void CodeGenLLVM::Visit_(const Select* op) {
value_ = builder_->CreateSelect(
MakeValue(op->condition),
MakeValue(op->true_value),
MakeValue(op->false_value));
}
void CodeGenLLVM::Visit_(const Let* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
value_ = MakeValue(op->body);
}
void CodeGenLLVM::Visit_(const Broadcast* op) {
value_ = CreateBroadcast(MakeValue(op->value), op->lanes);
}
void CodeGenLLVM::Visit_(const Ramp* op) {
Type t = op->type;
llvm::Value* base = MakeValue(op->base);
llvm::Value* stride = MakeValue(op->stride);
llvm::Value* value = llvm::UndefValue::get(LLVMType(t));
for (int i = 0; i < t.lanes(); ++i) {
if (i != 0) {
base = CreateAdd(t, base, stride);
}
value = builder_->CreateInsertElement(
value, base, llvm::ConstantInt::get(t_int32_, i));
}
value_ = value;
}
void CodeGenLLVM::Visit_(const Load* op) {
Type t = op->type;
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::DataLayout layout(module_.get());
uint64_t valign = layout.getTypeAllocSize(LLVMType(t));
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
valign);
AddAliasInfo(inst, op->buffer_var.get(), op->index);
value_ = inst;
} else {
LOG(FATAL) << "not yet supported";
}
}
void CodeGenLLVM::Visit_(const Store* op) {
llvm::Value* value = MakeValue(op->value);
Type t = op->value.type();
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::DataLayout layout(module_.get());
uint64_t valign = layout.getTypeAllocSize(value->getType());
llvm::StoreInst* inst = builder_->CreateAlignedStore(
value,
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
valign);
AddAliasInfo(inst, op->buffer_var.get(), op->index);
} else {
LOG(FATAL) << "not yet supported";
}
}
void CodeGenLLVM::Visit_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_global) ||
op->is_intrinsic(intrinsic::tvm_call_device)) {
value_ = CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
value_ = CreateIntrinstic(op);
} else {
CHECK(op->call_type == Call::Extern ||
op->call_type == Call::PureExtern);
value_ = CreateCallExtern(op);
}
}
llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
if (op->is_intrinsic(Call::bitwise_and)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateAnd(
MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::bitwise_xor)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateXor(
MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::bitwise_or)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateOr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::bitwise_not)) {
CHECK_EQ(op->args.size(), 1U);
return builder_->CreateNot(MakeValue(op->args[0]));
} else if (op->is_intrinsic(Call::shift_left)) {
CHECK_EQ(op->args.size(), 2U);
return builder_->CreateShl(
MakeValue(op->args[0]), MakeValue(op->args[1]));
} else if (op->is_intrinsic(Call::shift_right)) {
CHECK_EQ(op->args.size(), 2U);
if (op->type.is_int()) {
return builder_->CreateAShr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
} else {
return builder_->CreateLShr(
MakeValue(op->args[0]), MakeValue(op->args[1]));
}
} else if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>();
CHECK(op->args.size() == 1 && l);
return CreateBufferPtr(
l->type, GetVarValue(l->buffer_var.get()), MakeValue(l->index));
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U);
llvm::Value* ptr = MakeValue(op->args[0]);
return builder_->CreateICmpEQ(
ptr, llvm::Constant::getNullValue(ptr->getType()));
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) {
CHECK_EQ(op->args.size(), 3U);
CHECK_EQ(op->type.lanes(), 1);
llvm::Value* args = builder_->CreatePointerCast(
MakeValue(op->args[0]), t_tvm_value_->getPointerTo());
llvm::Value* ptr = builder_->CreateInBoundsGEP(
args, MakeValue(op->args[2]));
// always pass via 64 bit pointers
// For handle type, Handle(64) will simply become 32 bit void*
Type value_type = op->type.with_bits(64);
ptr = builder_->CreatePointerCast(
ptr, LLVMType(value_type)->getPointerTo());
llvm::Value* value = builder_->CreateAlignedLoad(ptr, 8);
// cast to the desired type
if (value_type != op->type) {
value = CreateCast(value_type, op->type, value);
}
return value;
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
llvm::Value* arr = builder_->CreatePointerCast(
MakeValue(op->args[0]), t_tvm_array_->getPointerTo());
llvm::Constant* zero = ConstInt32(0);
llvm::Value* ret = nullptr;
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(0)}); break;
}
case intrinsic::kShape: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(1)}); break;
}
case intrinsic::kStrides: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break;
}
case intrinsic::kNDim: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(3)}); break;
}
case intrinsic::kTypeCode: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(4), ConstInt32(0)}); break;
}
case intrinsic::kTypeBits: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(4), ConstInt32(1)}); break;
}
case intrinsic::kTypeLanes: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(4), ConstInt32(2)}); break;
}
default: LOG(FATAL) << "unknown field code";
}
return builder_->CreateLoad(ret);
} else {
LOG(FATAL) << "Unknown intrinstic " << op->name;
}
return nullptr;
}
llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function.
using llvm::BasicBlock;
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "call_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "call_end", function_);
llvm::Value* succ = builder_->CreateICmpEQ(
retcode, llvm::ConstantInt::get(t_int_, 0));
builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
builder_->SetInsertPoint(fail_block);
// return the code.
builder_->CreateRet(retcode);
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
return end_block;
}
void CodeGenLLVM::Visit_(const For* op) {
using llvm::BasicBlock;
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
CHECK(is_zero(op->min));
Type t = op->min.type();
llvm::Value* init = ConstInt32(0);
llvm::Value* extent = MakeValue(op->extent);
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(LLVMType(t), 2);
index->addIncoming(init, pre_block);
llvm::Value* cond = CreateLT(t, index, extent);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[op->loop_var.get()] = index;
this->Visit(op->body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
}
void CodeGenLLVM::Visit_(const IfThenElse* op) {
using llvm::BasicBlock;
BasicBlock* then_block = BasicBlock::Create(
*ctx_, "if_then", function_);
BasicBlock* else_block = BasicBlock::Create(
*ctx_, "if_else", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "if_end", function_);
if (!op->else_case.defined()) {
else_block = end_block;
}
// condition.
llvm::Value* cond = MakeValue(op->condition);
bool likely = true;
if (likely) {
builder_->CreateCondBr(cond, then_block, else_block, md_very_likely_branch_);
} else {
builder_->CreateCondBr(cond, then_block, else_block);
}
// then case.
builder_->SetInsertPoint(then_block);
this->Visit(op->then_case);
builder_->CreateBr(end_block);
// else case.
if (op->else_case.defined()) {
builder_->SetInsertPoint(else_block);
this->Visit(op->else_case);
builder_->CreateBr(end_block);
}
builder_->SetInsertPoint(end_block);
}
void CodeGenLLVM::Visit_(const Allocate* op) {
CHECK(!is_zero(op->condition));
llvm::Value* buf = nullptr;
if (op->new_expr.defined()) {
CHECK_EQ(op->free_function, "nop");
buf = MakeValue(op->new_expr);
} else {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
buf = builder_->CreateAlloca(
LLVMType(op->type), ConstInt32(constant_size));
}
buf = builder_->CreatePointerCast(buf, LLVMType(op->type)->getPointerTo());
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
}
void CodeGenLLVM::Visit_(const AttrStmt* op) {
this->Visit(op->body);
}
void CodeGenLLVM::Visit_(const AssertStmt* op) {
using llvm::BasicBlock;
llvm::Value* cond = MakeValue(op->condition);
std::ostringstream os;
os << "Assert fail: " << op->condition;
if (op->message.as<StringImm>()) {
os << ", " << op->message.as<StringImm>()->value;
}
llvm::Value* msg = GetConstString(os.str());
BasicBlock* fail_block = BasicBlock::Create(
*ctx_, "assert_fail", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "assert_end", function_);
builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
// fail condition.
builder_->SetInsertPoint(fail_block);
builder_->CreateCall(f_tvm_api_set_last_error_, {msg});
builder_->CreateRet(llvm::ConstantInt::getSigned(t_int32_, -1));
// otherwise set it to be new end.
builder_->SetInsertPoint(end_block);
}
void CodeGenLLVM::Visit_(const LetStmt* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
var_map_[op->var.get()] = v;
this->Visit(op->body);
}
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index) {
int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
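// Each step widens the accessed window to a power of two and aligns its base,
// chaining TBAA scalar type nodes root -> buffer -> buffer.w<width>.b<base> -> ...,
// so accesses whose constant windows are disjoint receive unrelated tags.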
const Ramp* ramp = index.as<Ramp>();
if (ramp) {
int base, stride;
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwidth = ramp->lanes * stride;
width = 1;
while (width < xwidth) {
width *= 2;
}
while (base % width) {
base -= base % width;
width *= 2;
}
}
} else {
if (arith::GetConstInt(index, &base)) width = 1;
}
llvm::MDNode* meta = md_tbaa_root_;
std::ostringstream buffer_addr;
buffer_addr << buffer;
meta = md_builder_->createTBAAScalarTypeNode(buffer_addr.str(), meta);
// create a tree-shape access structure.
if (width != 0) {
for (int w = 1024; w >= width; w /= 2) {
int b = (base / w) * w;
std::stringstream os;
os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
}
}
inst->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(meta, meta, 0));
}
llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
llvm::Constant* init = llvm::UndefValue::get(
llvm::VectorType::get(value->getType(), lanes));
llvm::Constant* zero = ConstInt32(0);
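// put the scalar into lane 0, then splat it across all lanes with an all-zero shuffle mask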
value = builder_->CreateInsertElement(init, value, zero);
llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
return builder_->CreateShuffleVector(value, init, mask);
}
llvm::Value* CodeGenLLVM::CreateBufferPtr(
Type t, llvm::Value* buffer, llvm::Value* index) {
llvm::Type* elem_type = buffer->getType();
unsigned address_space = elem_type->getPointerAddressSpace();
llvm::Type* load_type = LLVMType(t)->getPointerTo(address_space);
if (load_type != elem_type) {
buffer = builder_->CreatePointerCast(buffer, load_type);
}
llvm::Constant* cindex = llvm::dyn_cast<llvm::Constant>(index);
if (cindex && cindex->isZeroValue()) {
return buffer;
}
return builder_->CreateInBoundsGEP(buffer, index);
}
llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
llvm::Type * target = LLVMType(to);
if (value->getType() == target) return value;
if (from.is_handle() && to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (!from.is_float() && !to.is_float()) {
return builder_->CreateIntCast(value, target, from.is_int());
} else if (from.is_float() && to.is_int()) {
return builder_->CreateFPToSI(value, target);
} else if (from.is_float() && to.is_uint()) {
if (to.bits() < 8) {
value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
return builder_->CreateIntCast(value, target, false);
} else {
return builder_->CreateFPToUI(value, target);
}
} else if (from.is_int() && to.is_float()) {
return builder_->CreateSIToFP(value, target);
} else if (from.is_uint() && to.is_float()) {
return builder_->CreateUIToFP(value, target);
} else {
CHECK(from.is_float() && to.is_float());
return builder_->CreateFPCast(value, target);
}
}
llvm::Value* CodeGenLLVM::GetPackedFuncHandle(
const std::string& fname, bool global) {
using llvm::BasicBlock;
// We will store the packed function handle in global space.
// Initialize it during the first call.
llvm::DataLayout layout(module_.get());
uint64_t halign = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// create global location for the handle
// create the function handle
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
hptr->setAlignment(halign);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
hptr = it->second;
}
// create emit codes that checks and load the function.
BasicBlock* pre_block = builder_->GetInsertBlock();
BasicBlock* init_block = BasicBlock::Create(
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, halign);
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(
handle_not_null, end_block, init_block, md_very_likely_branch_);
// loaded handle, if created by call.
llvm::Value* loaded_handle = nullptr;
// Then block.
// We do not do lock here, so unlike static variable initialization
// This clause might be executed multiple times, but it is safe to do so.
builder_->SetInsertPoint(init_block);
if (global) {
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_func_get_global_, {GetConstString(fname), out});
init_block = CheckPackedCallSuccess(retcode);
loaded_handle = builder_->CreateAlignedLoad(out, halign);
} else {
LOG(FATAL) << "not yet supported";
}
builder_->CreateBr(end_block);
// end block
builder_->SetInsertPoint(end_block);
llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
phi->addIncoming(handle, pre_block);
phi->addIncoming(loaded_handle, init_block);
return phi;
}
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
CHECK_GE(op->args.size(), 1U);
std::string func_name = op->args[0].as<StringImm>()->value;
CHECK(!op->is_intrinsic(intrinsic::tvm_call_device))
<< "not implemented for now";
llvm::Value* handle = GetPackedFuncHandle(
func_name, op->is_intrinsic(intrinsic::tvm_call_global));
// call the function
unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
llvm::Value* targs = builder_->CreateAlloca(
t_tvm_value_, ConstInt32(nargs));
llvm::Value* tcodes = builder_->CreateAlloca(
t_int_, ConstInt32(nargs));
for (unsigned i = 0; i < nargs; ++i) {
Expr expr = op->args[i + 1];
Type t = expr.type();
CHECK_EQ(t.lanes(), 1);
// Always pass via 64 bit value.
// For handle type, Handle(64) maps to 32 bit void* in 32bit platform.
Type api_type = t.with_bits(64);
llvm::Value* value = CreateCast(t, api_type, MakeValue(expr));
llvm::Value* store_ptr = builder_->CreatePointerCast(
builder_->CreateInBoundsGEP(targs, ConstInt32(i)),
LLVMType(api_type)->getPointerTo());
builder_->CreateAlignedStore(value, store_ptr, 8);
builder_->CreateAlignedStore(
ConstInt32(t.code()),
builder_->CreateInBoundsGEP(tcodes, ConstInt32(i)), 4);
}
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
CheckPackedCallSuccess(
builder_->CreateCall(
f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
Type r_type = op->type;
Type r_api_type = op->type.with_bits(64);
llvm::Value* rvalue =
builder_->CreateAlignedLoad(
builder_->CreatePointerCast(
ret_value, LLVMType(r_api_type)->getPointerTo()), 8);
rvalue = CreateCast(r_api_type, r_type, rvalue);
return rvalue;
}
llvm::Value* CodeGenLLVM::CreateCallExtern(const Call* op) {
std::vector<llvm::Value*> arg_values(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
arg_values[i] = MakeValue(op->args[i]);
}
if (op->type.is_scalar()) {
llvm::Function* f = module_->getFunction(op->name);
if (f) {
return builder_->CreateCall(f, arg_values);
} else {
LOG(FATAL) << "cannot find function " << op->name;
}
} else {
llvm::Function* f = module_->getFunction(op->name);
if (f) {
return CreateScalarizedCall(op, f, arg_values);
} else {
LOG(FATAL) << "cannot find function " << op->name;
}
}
return nullptr;
}
llvm::Value* CodeGenLLVM::CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args) {
llvm::Value* value = llvm::UndefValue::get(LLVMType(op->type));
for (int i = 0; i < op->type.lanes(); ++i) {
std::vector<llvm::Value*> sargs(args.size());
for (size_t j = 0; j < args.size(); ++j) {
if (args[j]->getType()->isVectorTy()) {
sargs[j] = builder_->CreateExtractElement(args[j], ConstInt32(i));
} else {
sargs[j] = args[j];
}
}
llvm::CallInst* call = builder_->CreateCall(f, sargs);
if (op->is_pure()) {
call->setDoesNotAccessMemory();
}
call->setDoesNotThrow();
if (!call->getType()->isVoidTy()) {
value = builder_->CreateInsertElement(value, call, ConstInt32(i));
}
}
return value;
}
llvm::Value* CodeGenLLVM::GetVarValue(const Variable* v) const {
auto it = var_map_.find(v);
CHECK(it != var_map_.end())
<< "Cannot find " << v->name_hint << " in the var map";
return it->second;
}
llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
auto it = str_map_.find(str);
if (it == str_map_.end()) {
llvm::Type* type = llvm::ArrayType::get(t_char_, str.length() + 1);
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, true, llvm::GlobalValue::PrivateLinkage, 0, ".str");
global->setAlignment(1);
global->setInitializer(llvm::ConstantDataArray::getString(*ctx_, str));
// useful constant value
llvm::Constant* zero = ConstInt32(0);
llvm::Constant* indices[] = {zero, zero};
llvm::Constant* sptr = llvm::ConstantExpr::getGetElementPtr(
type, global, indices);
str_map_[str] = sptr;
return sptr;
} else {
return it->second;
}
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_llvm.h
* \brief Common base class for generating into LLVM IR
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/codegen.h>
#include <memory>
#include <vector>
#include <string>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
using namespace ir;
/*!
 * \brief A base class for generating LLVM IR.
*/
class CodeGenLLVM : public IRVisitor {
public:
/*!
* \brief Initialize the code generator with given context
* \param module_name The name of the module.
* \param ctx The context.
*/
void Init(const std::string& module_name, llvm::LLVMContext* ctx);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
*/
void AddFunction(const LoweredFunc& f);
/*!
* \brief Finish current pass of codegen, get the module.
* \return the created module.
*/
std::unique_ptr<llvm::Module> Finish();
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
* \return created value.
*/
llvm::Value* MakeValue(const Expr& e) {
value_ = nullptr;
this->Visit(e);
CHECK(value_ != nullptr);
return value_;
}
// Shorthand to get a constant int32
llvm::Constant* ConstInt32(unsigned value) const {
return llvm::ConstantInt::get(t_int32_, value);
}
// override codegen
void Visit_(const Variable* op) final;
void Visit_(const Cast* op) final;
void Visit_(const IntImm* op) final;
void Visit_(const UIntImm* op) final;
void Visit_(const FloatImm* op) final;
void Visit_(const StringImm* op) final;
void Visit_(const Add* op) final;
void Visit_(const Sub* op) final;
void Visit_(const Mul* op) final;
void Visit_(const Div* op) final;
void Visit_(const Mod* op) final;
void Visit_(const Min* op) final;
void Visit_(const Max* op) final;
void Visit_(const LT* op) final;
void Visit_(const LE* op) final;
void Visit_(const GT* op) final;
void Visit_(const GE* op) final;
void Visit_(const EQ* op) final;
void Visit_(const NE* op) final;
void Visit_(const And* op) final;
void Visit_(const Or* op) final;
void Visit_(const Not* op) final;
void Visit_(const Select* op) final;
void Visit_(const Let* op) final;
void Visit_(const Load* op) final;
void Visit_(const Call* op) final;
void Visit_(const Ramp* op) final;
void Visit_(const Broadcast* op) final;
// stmt
void Visit_(const Store* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Allocate* op) final;
void Visit_(const AttrStmt* op) override;
void Visit_(const AssertStmt* op) final;
void Visit_(const LetStmt* op) final;
// create intrinsic given the call
virtual llvm::Value* CreateIntrinstic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function.
virtual llvm::Value* CreateCallPacked(const Call* op);
protected:
/*!
* \param t The original type.
* \return LLVM type of t
*/
llvm::Type* LLVMType(const Type& t) const;
// do a scalarize call with f
llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// apply optimization on the module.
virtual void Optimize();
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
llvm::Function* function_;
// Internal builder
std::unique_ptr<IRBuilder> builder_;
// The module to be returned;
std::unique_ptr<llvm::Module> module_;
// Internal metabuilder
std::unique_ptr<llvm::MDBuilder> md_builder_;
// llvm context
llvm::LLVMContext* ctx_{nullptr};
// helpful data types
llvm::Type* t_void_{nullptr};
llvm::Type* t_void_p_{nullptr};
llvm::Type* t_int_{nullptr};
llvm::Type* t_char_{nullptr};
llvm::Type* t_int8_{nullptr};
llvm::Type* t_int16_{nullptr};
llvm::Type* t_int32_{nullptr};
llvm::Type* t_int64_{nullptr};
llvm::Type* t_float64_{nullptr};
// branch
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
// TVM related data types
llvm::Type* t_tvm_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
llvm::StructType* t_tvm_context_{nullptr};
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_func_get_global_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call.
llvm::Value* value_{nullptr};
private:
// comparison op
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* GetConstString(const std::string& str);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str, bool global);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index);
// The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
};
} // namespace codegen
} // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_common.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/base.h>
#include <mutex>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
struct LLVMEnv {
std::mutex mu;
bool native_initialized{false};
static LLVMEnv* Global() {
static LLVMEnv inst;
return &inst;
}
};
void InitializeLLVM() {
LLVMEnv* e = LLVMEnv::Global();
if (!e->native_initialized) {
std::lock_guard<std::mutex> lock(e->mu);
if (!e->native_initialized) {
e->native_initialized = true;
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
}
}
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_common.h
* \brief Common utilities for llvm initialization.
*/
#ifndef TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#define TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#ifdef TVM_LLVM_VERSION
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/TargetSelect.h>
extern "C" {
// Function signature for LLVM generated packed function.
typedef int (*LLVMPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace codegen {
/*!
* \brief Initialize LLVM on this process,
* can be called multiple times.
*/
void InitializeLLVM();
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_LLVM_COMMON_H_
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_exec_engine.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
#include "./llvm_common.h"
#include "./codegen_llvm.h"
namespace tvm {
namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
#ifdef TVM_LLVM_VERSION
// Environment to keep jit resources alive.
struct LLVMJITEnv {
std::shared_ptr<llvm::LLVMContext> ctx;
llvm::ExecutionEngine* ee{nullptr};
// constructor
LLVMJITEnv(std::shared_ptr<llvm::LLVMContext> ctx,
llvm::ExecutionEngine* ee)
: ctx(ctx), ee(ee) {
}
// destructor
~LLVMJITEnv() {
if (ee != nullptr) {
ee->runStaticConstructorsDestructors(true);
delete ee;
}
}
};
PackedFunc JITCompile(std::unique_ptr<llvm::Module> module,
std::shared_ptr<llvm::LLVMContext> ctx,
const std::string& func_name) {
llvm::EngineBuilder builder(std::move(module));
builder.setEngineKind(llvm::EngineKind::JIT);
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
std::shared_ptr<LLVMJITEnv> env = std::make_shared<LLVMJITEnv>(
ctx, builder.create());
CHECK(env->ee != nullptr);
auto* faddr = reinterpret_cast<LLVMPackedCFunc>(
env->ee->getFunctionAddress(func_name));
env->ee->runStaticConstructorsDestructors(false);
return PackedFunc([env, faddr](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK(ret == 0) << TVMGetLastError();
});
}
PackedFunc BuildLLVM(LoweredFunc func) {
InitializeLLVM();
// use one context per function.
std::shared_ptr<llvm::LLVMContext> ctx =
std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;
cg.Init(func->name, ctx.get());
cg.AddFunction(func);
std::unique_ptr<llvm::Module> m = cg.Finish();
return JITCompile(std::move(m), ctx, func->name);
}
#else
PackedFunc BuildLLVM(LoweredFunc func) {
LOG(FATAL) << "LLVM is not enabled";
return PackedFunc();
}
#endif // TVM_LLVM_VERSION
} // namespace codegen
} // namespace tvm
@@ -21,8 +21,6 @@
 */
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
-void TVMAPISetLastError(const char* msg);
/*!
 * \brief handle exception thrown out
 * \param e the exception
......
@@ -274,8 +274,9 @@ Stmt MakeLoop(const Stage& s,
bound_state[iv] = false;
}
PassUpBoundCheck(s, dom_map, &bound_state);
-auto nest = MakeLoopNest(s, dom_map, 0, false,
-bound_state, {}, &value_map);
+auto nest = MakeLoopNest(
+s, dom_map, 0, false,
+bound_state, {{}}, &value_map);
provide = Substitute(provide, value_map);
if (init.defined()) {
......
@@ -2,7 +2,6 @@ import tvm
import numpy as np

def test_add_pipeline():
-    """Not yet working, mock design"""
    n = tvm.Var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
......
@@ -6,6 +6,16 @@ def tvm_call_global(*args):
    return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0)

+def run_jit(fapi, check):
+    for target in ["stackvm"]:
+        if target == "llvm":
+            f = tvm.codegen.BuildLLVM(fapi)
+        else:
+            f = tvm.codegen.BuildStackVM(fapi)
+        check(f)
def test_stack_vm_basic():
    a = tvm.nd.array(np.zeros(10, dtype='float32'))
    @tvm.register_func
@@ -17,8 +27,7 @@ def test_stack_vm_basic():
    Ab = tvm.Buffer((n, ), tvm.float32)
    stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
    fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 1)
-    f = tvm.codegen.BuildStackVM(fapi)
-    f(a)
+    run_jit(fapi, lambda f: f(a))
@tvm.register_func
@@ -42,8 +51,10 @@ def test_stack_vm_loop():
    fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 1)
    f = tvm.codegen.BuildStackVM(fapi)
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
-    f(a)
-    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
+    def check(f):
+        f(a)
+        np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
+    run_jit(fapi, check)
def test_stack_vm_cond():
@@ -61,15 +72,46 @@ def test_stack_vm_cond():
            tvm.make.Store(Ab.data,
                           tvm.make.Load(dtype, Ab.data, i) + 2, i + 1)))
    fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 1)
-    f = tvm.codegen.BuildStackVM(fapi)
-    a = tvm.nd.array(np.zeros(10, dtype=dtype))
-    f(a)
-    y = np.arange(a.shape[0]) * 2
-    y[5:] -= 1
-    np.testing.assert_equal(a.asnumpy(), y)
+    def check(f):
+        a = tvm.nd.array(np.zeros(10, dtype=dtype))
+        f(a)
+        y = np.arange(a.shape[0]) * 2
+        y[5:] -= 1
+        np.testing.assert_equal(a.asnumpy(), y)
+    run_jit(fapi, check)
def test_llvm_add_pipeline():
    n = tvm.Var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
    s = tvm.Schedule(C.op)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
    Bb = tvm.Buffer(B.shape, B.dtype, name='B')
    Cb = tvm.Buffer(C.shape, C.dtype, name='C')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb})
    stmt = tvm.ir_pass.Simplify(stmt)
    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3)
    def check_llvm():
        # build and invoke the kernel.
        f = tvm.codegen.BuildLLVM(fapi)
        ctx = tvm.cpu(0)
        # launch the kernel.
        n = 1027
        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
        f(a, b, c)
        np.testing.assert_allclose(
            c.asnumpy(), a.asnumpy() + b.asnumpy())
    # check_llvm()
if __name__ == "__main__": if __name__ == "__main__":
test_stack_vm_cond() test_stack_vm_cond()
test_stack_vm_loop()
test_stack_vm_basic() test_stack_vm_basic()
test_stack_vm_loop()
test_llvm_add_pipeline()