Commit 8f240ee7 by Tianqi Chen (committed via GitHub)

[CODEGEN/LLVM] Initial support for codegen LLVM. (#49)

* [LLVM] Initial support for codegen LLVM.

* Fix the naming issue of codegen
parent 3555769e
Subproject commit e68ae61cd541ac29efc9fafe2ad061479bcaa9c9
Subproject commit 1a11a6c2522b1d11a5ccdb9b4fe3976cbe7f9f27
ifndef config
ifneq ("$(wildcard ./config.mk)","")
config = config.mk
config ?= config.mk
else
config = make/config.mk
config ?= make/config.mk
endif
endif
......@@ -19,24 +19,16 @@ SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
ifneq ($(USE_CUDA_PATH), NONE)
NVCC=$(USE_CUDA_PATH)/bin/nvcc
endif
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
export FRAMEWORKS=
ifneq ($(ADD_CFLAGS), NONE)
CFLAGS += $(ADD_CFLAGS)
endif
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
ifdef CUDA_PATH
NVCC=$(CUDA_PATH)/bin/nvcc
CFLAGS += -I$(CUDA_PATH)/include
LDFLAGS += -L$(CUDA_PATH)/lib64
endif
ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc
......@@ -44,6 +36,7 @@ else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif
FRAMEWORKS=
ifeq ($(USE_OPENCL), 1)
CFLAGS += -DTVM_OPENCL_RUNTIME=1
......@@ -57,6 +50,23 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif
# llvm configuration
LLVM_CONFIG=llvm-config
ifeq ($(USE_LLVM), 1)
LLVM_VERSION=$(shell $(LLVM_CONFIG) --version| cut -b 1,3)
LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION)
endif
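With USE_LLVM=1, an llvm-config --version output such as 4.0.1 makes "cut -b 1,3" keep bytes 1 and 3, so LLVM_VERSION becomes 40 and every translation unit is compiled with -DTVM_LLVM_VERSION=40. A minimal sketch of how guarded C++ sources can consume that macro (the concrete threshold is illustrative, based on the >= 4.0 requirement stated in config.mk later in this diff):

// Hypothetical consumer of the Makefile-provided macro.
#ifdef TVM_LLVM_VERSION         // only defined when built with USE_LLVM=1
#if TVM_LLVM_VERSION >= 40      // 40 encodes LLVM 4.0
// LLVM 4.0+ code paths go here.
#endif
#endif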
ifdef ADD_CFLAGS
CFLAGS += $(ADD_CFLAGS)
endif
ifdef ADD_LDFLAGS
LDFLAGS += $(ADD_LDFLAGS)
endif
include tests/cpp/unittest.mk
......
Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6
Subproject commit 8dd365636528175e785448cf8a9f4e494c8ee0e0
......@@ -90,7 +90,7 @@ class BufferNode : public Node {
Type dtype);
static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
};
inline const BufferNode* Buffer::operator->() const {
......
......@@ -32,6 +32,13 @@ PackedFunc BuildStackVM(
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);
/*!
* \brief Build an LLVM VM function; this is still beta.
* \param func The LoweredFunc to be built.
* \return A packed function representing the func.
*/
PackedFunc BuildLLVM(LoweredFunc func);
/*!
* \brief Build a CUDA function with NVRTC
*
* \param fsplits The LoweredFuncs to be built (after SplitHostDevice)
......
......@@ -36,7 +36,6 @@ using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}
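TVMType2Type bridges the C runtime's plain type descriptor and Halide's Type by casting the numeric code field. A quick hedged illustration (the code value 2 for float follows the usual DLPack/Halide convention and is an assumption here):

// Sketch: a float32 scalar descriptor mapped to a Halide type.
TVMType t;
t.code = 2;   // assumed: 2 denotes float in both TVMType and halide_type_code_t
t.bits = 32;
t.lanes = 1;
Type ht = TVMType2Type(t);  // equivalent to Float(32)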
......@@ -182,7 +181,7 @@ class IterVarNode : public Node {
static IterVar make(Range dom, Var var, std::string thread_tag);
static constexpr const char* _type_key = "IterVar";
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
};
// inline implementations
......
......@@ -200,6 +200,8 @@ using Halide::Internal::Realize;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
// ir functions
using Halide::Internal::is_const_power_of_two_integer;
} // namespace ir
} // namespace tvm
......
......@@ -92,7 +92,7 @@ class LoweredFuncNode : public FunctionBaseNode {
}
static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
};
// Implementations of inline functions
......
......@@ -39,7 +39,7 @@ class PlaceholderOpNode : public OperationNode {
Type dtype);
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
};
/*!
......@@ -74,7 +74,7 @@ class ComputeOpNode : public OperationNode {
Expr body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
};
/*!
......@@ -123,7 +123,7 @@ class ScanOpNode : public OperationNode {
Array<Tensor> state_placeholder);
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
};
......
......@@ -33,7 +33,7 @@ struct NodeTypeChecker {
// It can be turned off, but that makes the checking non-strict.
// TODO(tqchen) possibly find an alternative to turning off RTTI
using ContainerType = typename T::ContainerType;
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
return sptr->derived_from<ContainerType>();
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
......
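The hunk above is the heart of the RTTI removal: with -fno-rtti and -DDMLC_ENABLE_RTTI=0 now in CFLAGS, dynamic_cast is unavailable, so NodeTypeChecker asks the node hierarchy itself via derived_from. A simplified, self-contained sketch of how such a check can work (hypothetical types; the real machinery lives behind TVM_DECLARE_BASE_NODE_INFO / TVM_DECLARE_NODE_TYPE_INFO, whose new second argument names the parent class):

#include <iostream>

// Simplified sketch of an RTTI-free derived_from check (hypothetical types).
struct TypeInfo {
  const char* key;
  const TypeInfo* parent;
};

struct Node {
  static const TypeInfo info;
  virtual const TypeInfo* type_info() const { return &info; }
  virtual ~Node() = default;
  // Walk the static parent chain instead of using dynamic_cast.
  template <typename T>
  bool derived_from() const {
    for (const TypeInfo* t = type_info(); t != nullptr; t = t->parent) {
      if (t == &T::info) return true;
    }
    return false;
  }
};
const TypeInfo Node::info = {"Node", nullptr};

struct OperationNode : Node {
  static const TypeInfo info;
  const TypeInfo* type_info() const override { return &info; }
};
const TypeInfo OperationNode::info = {"Operation", &Node::info};

struct ComputeOpNode : OperationNode {
  static const TypeInfo info;
  const TypeInfo* type_info() const override { return &info; }
};
const TypeInfo ComputeOpNode::info = {"ComputeOp", &OperationNode::info};

int main() {
  ComputeOpNode op;
  Node* sptr = &op;
  std::cout << sptr->derived_from<OperationNode>() << "\n";  // 1: walks one level up
  std::cout << sptr->derived_from<Node>() << "\n";           // 1: reaches the root
}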
......@@ -153,6 +153,13 @@ typedef void* TVMRetValueHandle;
typedef TVMArray* TVMArrayHandle;
/*!
* \brief Used for implementing C API functions.
* Sets the last error message before returning.
* \param msg The error message to be set.
*/
TVM_DLL void TVMAPISetLastError(const char* msg);
/*!
* \brief Return the string message of the last error.
* All functions in this file return 0 on success
* and -1 when an error occurred,
......@@ -287,10 +294,10 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
* \param num_args Number of arguments.
* \param ret The return value handle.
* \param resource_handle The additional resource handle from the front-end.
*
* \return 0 on success, -1 on failure; set the error via TVMAPISetLastError.
* \sa TVMCFuncSetReturn
*/
typedef void (*TVMPackedCFunc)(
typedef int (*TVMPackedCFunc)(
TVMValue* args,
int* type_codes,
int num_args,
......
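The void-to-int change above lets C callbacks report failure instead of failing silently; together with TVMAPISetLastError from the earlier hunk, the convention is: return 0 on success, -1 after setting the error. A hedged sketch of a conforming callback (the trailing ret and resource_handle parameters follow the \param documentation above; the include path and names are assumptions):

#include <tvm/runtime/c_runtime_api.h>  // assumed path of the C API header excerpted above

// Illustrative callback under the new convention.
extern "C" int MyPackedCFunc(TVMValue* args, int* type_codes, int num_args,
                             TVMRetValueHandle ret, void* resource_handle) {
  (void)args; (void)type_codes; (void)ret; (void)resource_handle;
  if (num_args < 1) {
    TVMAPISetLastError("MyPackedCFunc expects at least one argument");
    return -1;  // surfaced to callers via TVMGetLastError
  }
  return 0;
}

The Python-side changes later in this diff mirror the same convention: convert_to_tvm_func now returns 0 from its cfun wrapper, and the ctypes TVMPackedCFunc restype switches from None to c_int.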
......@@ -331,7 +331,7 @@ class StageNode : public Node {
}
static constexpr const char* _type_key = "Stage";
TVM_DECLARE_NODE_TYPE_INFO(StageNode);
TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
};
/*! \brief node container for schedule */
......@@ -354,7 +354,7 @@ class ScheduleNode : public Node {
}
static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};
/*! \brief node container for IterVar attr */
......@@ -368,11 +368,14 @@ class IterVarAttrNode : public Node {
}
static constexpr const char* _type_key = "IterVarAttr";
TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode);
TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
};
/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
public:
static constexpr const char* _type_key = "IterVarRelation";
TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
};
/*!
......@@ -402,7 +405,7 @@ class SplitNode : public IterVarRelationNode {
IterVar inner, Expr factor);
static constexpr const char* _type_key = "Split";
TVM_DECLARE_NODE_TYPE_INFO(SplitNode);
TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
};
/*!
......@@ -427,7 +430,7 @@ class FuseNode : public IterVarRelationNode {
IterVar outer, IterVar inner, IterVar fused);
static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
};
/*!
......@@ -450,7 +453,7 @@ class RebaseNode : public IterVarRelationNode {
static IterVarRelation make(IterVar parent, IterVar rebased);
static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
};
......
......@@ -153,7 +153,7 @@ class TensorNode : public Node {
int value_index);
static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_NODE_TYPE_INFO(TensorNode);
TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node);
};
/*!
......@@ -167,8 +167,6 @@ class OperationNode : public FunctionBaseNode {
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return type of i-th output */
......@@ -177,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
virtual Array<Expr> output_shape(size_t i) const = 0;
static constexpr const char* _type_key = "Operation";
TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
};
// Implementations of inline functions
......
......@@ -40,10 +40,11 @@ USE_CUDA = 1
# whether to use OpenCL during compilation
USE_OPENCL = 0
# add the path to CUDA library to link and compile flag
# if you have already added them to your environment variables, leave it as NONE
# USE_CUDA_PATH = /usr/local/cuda
USE_CUDA_PATH = NONE
# whether to build with LLVM support
# This requires llvm-config to be in your PATH
# Requires LLVM version >= 4.0
USE_LLVM = 0
# whether to use CUDA runtime compilation for writing kernels in a native language (i.e. Python)
USE_NVRTC = 0
# add the path to CUDA library to link and compile flag
# if you have already added them to your environment variables.
# CUDA_PATH = /usr/local/cuda
......@@ -56,6 +56,7 @@ def convert_to_tvm_func(pyfunc):
check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
_ = temp_args
_ = rv
return 0
handle = FunctionHandle()
f = TVMPackedCFunc(cfun)
......
......@@ -96,7 +96,7 @@ class TVMByteArray(ctypes.Structure):
TVMPackedCFunc = ctypes.CFUNCTYPE(
None,
ctypes.c_int,
ctypes.POINTER(TVMValue),
ctypes.POINTER(ctypes.c_int),
ctypes.c_int,
......
......@@ -6,3 +6,4 @@
- arithmetic Arithmetic expression and set simplification
- pass The optimization pass on the IR structure
- runtime Minimal runtime-related code
- codegen The code generator
......@@ -37,6 +37,11 @@ TVM_REGISTER_API(_codegen_BuildStackVM)
std::unordered_map<LoweredFunc, PackedFunc>());
});
TVM_REGISTER_API(_codegen_BuildLLVM)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildLLVM(args[0]);
});
TVM_REGISTER_API(_codegen_BuildNVRTC)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildNVRTC(args[0], args[1]);
......
......@@ -20,7 +20,7 @@ enum SignType {
};
// internal node container of int set.
class IntSetNode;
struct IntSetNode;
/*!
* \brief Integer set class, represent a set of integers in one dimension.
......@@ -104,6 +104,8 @@ class IntSet : public NodeRef {
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};
using ExprIntSetMap = std::unordered_map<Expr, IntSet,
......
......@@ -35,7 +35,7 @@ struct IntervalSet : public IntSetNode {
}
static constexpr const char* _type_key = "IntervalSet";
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet);
TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode);
};
/*!
......@@ -51,7 +51,7 @@ struct StrideSet : public IntSetNode {
Array<Expr> strides;
static constexpr const char* _type_key = "StrideSet";
TVM_DECLARE_NODE_TYPE_INFO(StrideSet);
TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};
} // namespace arith
......
......@@ -272,9 +272,6 @@ inline void PushBinary(StackVM::OpCode op_int64,
}
}
inline void PushCast(Type dst,
Type src,
CodeGenStackVM* p) {
......@@ -496,7 +493,5 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) {
p->Push_(op);
});
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_llvm.h
* \brief Common base class for generating LLVM IR
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/codegen.h>
#include <memory>
#include <vector>
#include <string>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief A base class for generating LLVM IR.
*/
class CodeGenLLVM : public IRVisitor {
public:
/*!
* \brief Initialize the code generator with given context
* \param module_name The name of the module.
* \param ctx The context.
*/
void Init(const std::string& module_name, llvm::LLVMContext* ctx);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
*/
void AddFunction(const LoweredFunc& f);
/*!
* \brief Finish the current pass of codegen and return the module.
* \return the created module.
*/
std::unique_ptr<llvm::Module> Finish();
/*!
* \brief Create Value for expression e
* \param e The expression to create the value for.
* \return the created value.
*/
llvm::Value* MakeValue(const Expr& e) {
value_ = nullptr;
this->Visit(e);
CHECK(value_ != nullptr);
return value_;
}
// Shorthand to get a constant int32
llvm::Constant* ConstInt32(unsigned value) const {
return llvm::ConstantInt::get(t_int32_, value);
}
// override codegen
void Visit_(const Variable* op) final;
void Visit_(const Cast* op) final;
void Visit_(const IntImm* op) final;
void Visit_(const UIntImm* op) final;
void Visit_(const FloatImm* op) final;
void Visit_(const StringImm* op) final;
void Visit_(const Add* op) final;
void Visit_(const Sub* op) final;
void Visit_(const Mul* op) final;
void Visit_(const Div* op) final;
void Visit_(const Mod* op) final;
void Visit_(const Min* op) final;
void Visit_(const Max* op) final;
void Visit_(const LT* op) final;
void Visit_(const LE* op) final;
void Visit_(const GT* op) final;
void Visit_(const GE* op) final;
void Visit_(const EQ* op) final;
void Visit_(const NE* op) final;
void Visit_(const And* op) final;
void Visit_(const Or* op) final;
void Visit_(const Not* op) final;
void Visit_(const Select* op) final;
void Visit_(const Let* op) final;
void Visit_(const Load* op) final;
void Visit_(const Call* op) final;
void Visit_(const Ramp* op) final;
void Visit_(const Broadcast* op) final;
// stmt
void Visit_(const Store* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Allocate* op) final;
void Visit_(const AttrStmt* op) override;
void Visit_(const AssertStmt* op) final;
void Visit_(const LetStmt* op) final;
// create intrinsic given a call
virtual llvm::Value* CreateIntrinstic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function.
virtual llvm::Value* CreateCallPacked(const Call* op);
protected:
/*!
* \param t The original type.
* \return LLVM type of t
*/
llvm::Type* LLVMType(const Type& t) const;
// do a scalarized call with f
llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// apply optimization on the module.
virtual void Optimize();
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
llvm::Function* function_;
// Internal builder
std::unique_ptr<IRBuilder> builder_;
// The module to be returned.
std::unique_ptr<llvm::Module> module_;
// Internal metadata builder
std::unique_ptr<llvm::MDBuilder> md_builder_;
// llvm context
llvm::LLVMContext* ctx_{nullptr};
// helpful data types
llvm::Type* t_void_{nullptr};
llvm::Type* t_void_p_{nullptr};
llvm::Type* t_int_{nullptr};
llvm::Type* t_char_{nullptr};
llvm::Type* t_int8_{nullptr};
llvm::Type* t_int16_{nullptr};
llvm::Type* t_int32_{nullptr};
llvm::Type* t_int64_{nullptr};
llvm::Type* t_float64_{nullptr};
// branch and TBAA metadata
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
// TVM related data types
llvm::Type* t_tvm_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
llvm::StructType* t_tvm_context_{nullptr};
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_func_get_global_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
// The current basic block being generated
llvm::BasicBlock* block_{nullptr};
// Last value returned by a codegen call.
llvm::Value* value_{nullptr};
private:
// helpers: variable lookup, comparisons, arithmetic
llvm::Value* GetVarValue(const Variable* v) const;
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* GetConstString(const std::string& str);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str, bool global);
// Check if a call to a packed function succeeded;
// if not, finalize the function and pass on the return code.
// Returns the end block after the check.
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index);
// Definitions of local variables.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
};
} // namespace codegen
} // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
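To make the header's visitor protocol concrete: MakeValue clears value_, dispatches Visit, and expects the matching Visit_ override to leave its result in value_. A hedged sketch of one override, using only helpers declared above (the actual codegen_llvm.cc implementation is not part of this excerpt):

// Sketch only: one Visit_ override in the style the header implies.
void CodeGenLLVM::Visit_(const Add* op) {
  // Evaluate both operands recursively, then combine with the typed helper;
  // a CreateAdd of this shape would pick integer vs floating-point add from op->type.
  value_ = CreateAdd(op->type, MakeValue(op->a), MakeValue(op->b));
}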
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_common.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/base.h>
#include <mutex>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
struct LLVMEnv {
std::mutex mu;
bool native_initialized{false};
static LLVMEnv* Global() {
static LLVMEnv inst;
return &inst;
}
};
void InitializeLLVM() {
LLVMEnv* e = LLVMEnv::Global();
if (!e->native_initialized) {
// the guard must be a named object: a bare temporary would unlock immediately
std::lock_guard<std::mutex> lock(e->mu);
if (!e->native_initialized) {
e->native_initialized = true;
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
}
}
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_common.h
* \brief Common utilities for llvm initialization.
*/
#ifndef TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#define TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#ifdef TVM_LLVM_VERSION
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/TargetSelect.h>
extern "C" {
// Function signature for an LLVM-generated packed function.
typedef int (*LLVMPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace codegen {
/*!
* \brief Initialize LLVM for this process;
*  safe to call multiple times.
*/
void InitializeLLVM();
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_LLVM_COMMON_H_
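The extern "C" typedef pins down the ABI contract between JIT-compiled code and the runtime: three arguments, an int return, and 0 meaning success. A hand-written stand-in with the same shape (illustrative only) shows what getFunctionAddress must be able to cast to:

// Illustrative stand-in with the LLVMPackedCFunc shape; code generated by
// CodeGenLLVM must be callable exactly like this.
extern "C" int my_generated_func(void* args, int* type_codes, int num_args) {
  (void)args; (void)type_codes; (void)num_args;
  return 0;  // nonzero would be reported through TVMGetLastError by the wrapper
}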
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_exec_engine.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
#include "./llvm_common.h"
#include "./codegen_llvm.h"
namespace tvm {
namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
#ifdef TVM_LLVM_VERSION
// Environment to keep jit resources alive.
struct LLVMJITEnv {
std::shared_ptr<llvm::LLVMContext> ctx;
llvm::ExecutionEngine* ee{nullptr};
// constructor
LLVMJITEnv(std::shared_ptr<llvm::LLVMContext> ctx,
llvm::ExecutionEngine* ee)
: ctx(ctx), ee(ee) {
}
// destructor
~LLVMJITEnv() {
if (ee != nullptr) {
ee->runStaticConstructorsDestructors(true);
delete ee;
}
}
};
PackedFunc JITCompile(std::unique_ptr<llvm::Module> module,
std::shared_ptr<llvm::LLVMContext> ctx,
const std::string& func_name) {
llvm::EngineBuilder builder(std::move(module));
builder.setEngineKind(llvm::EngineKind::JIT);
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
std::shared_ptr<LLVMJITEnv> env = std::make_shared<LLVMJITEnv>(
ctx, builder.create());
CHECK(env->ee != nullptr);
auto* faddr = reinterpret_cast<LLVMPackedCFunc>(
env->ee->getFunctionAddress(func_name));
env->ee->runStaticConstructorsDestructors(false);
return PackedFunc([env, faddr](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK(ret == 0) << TVMGetLastError();
});
}
PackedFunc BuildLLVM(LoweredFunc func) {
InitializeLLVM();
// use one context per function.
std::shared_ptr<llvm::LLVMContext> ctx =
std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;
cg.Init(func->name, ctx.get());
cg.AddFunction(func);
std::unique_ptr<llvm::Module> m = cg.Finish();
return JITCompile(std::move(m), ctx, func->name);
}
#else
PackedFunc BuildLLVM(LoweredFunc func) {
LOG(FATAL) << "LLVM is not enabled";
return PackedFunc();
}
#endif // TVM_LLVM_VERSION
} // namespace codegen
} // namespace tvm
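Putting the pieces together, a caller only ever sees a PackedFunc; whether it is backed by the stack VM or by MCJIT is hidden. A hedged host-side sketch (the wrapper name is hypothetical; the LoweredFunc is assumed to come from the lowering pipeline, e.g. MakeAPI):

#include <tvm/codegen.h>              // declares BuildLLVM per the hunk above
#include <tvm/runtime/packed_func.h>

// Sketch: JIT a lowered function once and keep the functor around.
// The PackedFunc returned by BuildLLVM captures LLVMJITEnv, so the
// ExecutionEngine and LLVMContext stay alive as long as the functor is held.
tvm::runtime::PackedFunc JitLowered(const tvm::LoweredFunc& func) {
  return tvm::codegen::BuildLLVM(func);
}

From Python, the same entry point is reached through tvm.codegen.BuildLLVM, registered earlier in this diff as _codegen_BuildLLVM.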
......@@ -21,8 +21,6 @@
*/
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)
void TVMAPISetLastError(const char* msg);
/*!
* \brief handle a thrown exception
* \param e the exception
......
......@@ -274,8 +274,9 @@ Stmt MakeLoop(const Stage& s,
bound_state[iv] = false;
}
PassUpBoundCheck(s, dom_map, &bound_state);
auto nest = MakeLoopNest(s, dom_map, 0, false,
bound_state, {}, &value_map);
auto nest = MakeLoopNest(
s, dom_map, 0, false,
bound_state, {{}}, &value_map);
provide = Substitute(provide, value_map);
if (init.defined()) {
......
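The {} to {{}} change above is subtle: for a container-of-containers parameter, {} constructs an empty outer container, while {{}} constructs an outer container holding one empty inner element. A self-contained illustration (the element types are assumptions; MakeLoopNest's real parameter type is not shown in this diff):

#include <iostream>
#include <unordered_set>
#include <vector>

int main() {
  using SkipSets = std::vector<std::unordered_set<int>>;
  SkipSets a{};    // empty vector: no inner sets at all
  SkipSets b{{}};  // one element: a single empty set
  std::cout << a.size() << " " << b.size() << "\n";  // prints "0 1"
}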
......@@ -2,7 +2,6 @@ import tvm
import numpy as np
def test_add_pipeline():
"""Not yet working, mock design"""
n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
......
......@@ -6,6 +6,16 @@ def tvm_call_global(*args):
return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0)
def run_jit(fapi, check):
for target in ["stackvm"]:
if target == "llvm":
f = tvm.codegen.BuildLLVM(fapi)
else:
f = tvm.codegen.BuildStackVM(fapi)
check(f)
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
......@@ -17,8 +27,7 @@ def test_stack_vm_basic():
Ab = tvm.Buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
f(a)
run_jit(fapi, lambda f: f(a))
@tvm.register_func
......@@ -42,8 +51,10 @@ def test_stack_vm_loop():
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
def check(f):
f(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
run_jit(fapi, check)
def test_stack_vm_cond():
......@@ -61,15 +72,46 @@ def test_stack_vm_cond():
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 2, i + 1)))
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
y = np.arange(a.shape[0]) * 2
y[5:] -= 1
np.testing.assert_equal(a.asnumpy(), y)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
y = np.arange(a.shape[0]) * 2
y[5:] -= 1
np.testing.assert_equal(a.asnumpy(), y)
run_jit(fapi, check)
def test_llvm_add_pipeline():
n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3)
def check_llvm():
# build and invoke the kernel.
f = tvm.codegen.BuildLLVM(fapi)
ctx = tvm.cpu(0)
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
#check_llvm()
if __name__ == "__main__":
test_stack_vm_cond()
test_stack_vm_loop()
test_stack_vm_basic()
test_stack_vm_loop()
test_llvm_add_pipeline()