Commit 8f240ee7 by Tianqi Chen, committed by GitHub

[CODEGEN/LLVM] Initial support for codegen LLVM. (#49)

* [LLVM] Initial support for codegen LLVM.

* Fix the naming issue of codegen
parent 3555769e
-Subproject commit e68ae61cd541ac29efc9fafe2ad061479bcaa9c9
+Subproject commit 1a11a6c2522b1d11a5ccdb9b4fe3976cbe7f9f27
ifndef config
ifneq ("$(wildcard ./config.mk)","")
-config = config.mk
+config ?= config.mk
else
-config = make/config.mk
+config ?= make/config.mk
endif
endif
@@ -19,24 +19,16 @@ SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
-ifneq ($(USE_CUDA_PATH), NONE)
-NVCC=$(USE_CUDA_PATH)/bin/nvcc
-endif
export LDFLAGS = -pthread -lm
-export CFLAGS = -std=c++11 -Wall -O2\
-	-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
+export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
+	-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
-export FRAMEWORKS=
-ifneq ($(ADD_CFLAGS), NONE)
-CFLAGS += $(ADD_CFLAGS)
-endif
-ifneq ($(ADD_LDFLAGS), NONE)
-LDFLAGS += $(ADD_LDFLAGS)
+ifdef CUDA_PATH
+NVCC=$(CUDA_PATH)/bin/nvcc
+CFLAGS += -I$(CUDA_PATH)/include
+LDFLAGS += -L$(CUDA_PATH)/lib64
endif
ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart -lnvrtc
@@ -44,6 +36,7 @@ else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif

+FRAMEWORKS=
+
ifeq ($(USE_OPENCL), 1)
CFLAGS += -DTVM_OPENCL_RUNTIME=1
@@ -57,6 +50,23 @@ else
CFLAGS += -DTVM_OPENCL_RUNTIME=0
endif
+# llvm configuration
+LLVM_CONFIG=llvm-config
+
+ifeq ($(USE_LLVM), 1)
+LLVM_VERSION=$(shell $(LLVM_CONFIG) --version | cut -b 1,3)
+LLVM_INCLUDE=$(filter -I%, $(shell $(LLVM_CONFIG) --cxxflags))
+LDFLAGS += $(shell $(LLVM_CONFIG) --ldflags --libs --system-libs)
+CFLAGS += $(LLVM_INCLUDE) -DTVM_LLVM_VERSION=$(LLVM_VERSION)
+endif
+
+ifdef ADD_CFLAGS
+CFLAGS += $(ADD_CFLAGS)
+endif
+
+ifdef ADD_LDFLAGS
+LDFLAGS += $(ADD_LDFLAGS)
+endif
include tests/cpp/unittest.mk
......
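A note on the new flags: "llvm-config --version | cut -b 1,3" turns a version such as 4.0.1 into the two digits 40, which the Makefile passes along as TVM_LLVM_VERSION. A minimal sketch of how the new sources consume that macro (the inner version test is an assumed pattern, not from this commit):

// Sketch: the LLVM sources below compile only when USE_LLVM=1 defines
// TVM_LLVM_VERSION; the value 40 corresponds to LLVM 4.0.x.
#ifdef TVM_LLVM_VERSION
#if TVM_LLVM_VERSION >= 40
// version-gated LLVM code would live here
#endif
#endif  // TVM_LLVM_VERSION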
-Subproject commit 3a51614d39b69fdb5de1efcf1016426626d267a6
+Subproject commit 8dd365636528175e785448cf8a9f4e494c8ee0e0
@@ -90,7 +90,7 @@ class BufferNode : public Node {
                     Type dtype);

  static constexpr const char* _type_key = "Buffer";
-  TVM_DECLARE_NODE_TYPE_INFO(BufferNode);
+  TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
};

inline const BufferNode* Buffer::operator->() const {
......
@@ -32,6 +32,13 @@ PackedFunc BuildStackVM(
    const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);

/*!
+ * \brief Build an LLVM function; this is still beta.
+ * \param func The LoweredFunc to be built.
+ * \return A packed function representing the func.
+ */
+PackedFunc BuildLLVM(LoweredFunc func);
+
+/*!
 * \brief Build a CUDA function with NVRTC
 *
 * \param fsplits The LoweredFuncs to be built (after SplitHostDevice)
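As a usage sketch of the new entry point (the wrapper function below is hypothetical; BuildLLVM, LoweredFunc, and PackedFunc come from this diff):

#include <tvm/codegen.h>

// Hypothetical caller: JIT-compile one lowered function into a callable
// PackedFunc. When USE_LLVM=0, the fallback BuildLLVM logs a fatal error.
tvm::runtime::PackedFunc CompileWithLLVM(tvm::LoweredFunc func) {
  return tvm::codegen::BuildLLVM(func);
}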
@@ -36,7 +36,6 @@ using Halide::Internal::make_zero;
using Halide::Internal::as_const_int;
using Halide::Internal::as_const_uint;

inline Type TVMType2Type(TVMType t) {
  return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}
@@ -182,7 +181,7 @@ class IterVarNode : public Node {
  static IterVar make(Range dom, Var var, std::string thread_tag);

  static constexpr const char* _type_key = "IterVar";
-  TVM_DECLARE_NODE_TYPE_INFO(IterVarNode);
+  TVM_DECLARE_NODE_TYPE_INFO(IterVarNode, Node);
};

// inline implementations
......
@@ -200,6 +200,8 @@ using Halide::Internal::Realize;
using Halide::Internal::Block;
using Halide::Internal::IfThenElse;
using Halide::Internal::Evaluate;
+// ir functions
+using Halide::Internal::is_const_power_of_two_integer;

}  // namespace ir
}  // namespace tvm
......
@@ -92,7 +92,7 @@ class LoweredFuncNode : public FunctionBaseNode {
  }

  static constexpr const char* _type_key = "LoweredFunc";
-  TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
+  TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode, Node);
};

// Implementations of inline functions
......
@@ -39,7 +39,7 @@ class PlaceholderOpNode : public OperationNode {
                       Type dtype);

  static constexpr const char* _type_key = "PlaceholderOp";
-  TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
+  TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode, OperationNode);
};

/*!
@@ -74,7 +74,7 @@ class ComputeOpNode : public OperationNode {
                        Expr body);

  static constexpr const char* _type_key = "ComputeOp";
-  TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
+  TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
};

/*!
@@ -123,7 +123,7 @@ class ScanOpNode : public OperationNode {
                        Array<Tensor> state_placeholder);

  static constexpr const char* _type_key = "ScanOp";
-  TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode);
+  TVM_DECLARE_NODE_TYPE_INFO(ScanOpNode, OperationNode);
};
......
@@ -33,7 +33,7 @@ struct NodeTypeChecker {
    // It can be turned off, but that makes the checking non-strict.
    // TODO(tqchen) possibly find another way to work without RTTI
    using ContainerType = typename T::ContainerType;
-    return (dynamic_cast<ContainerType*>(sptr) != nullptr);
+    return sptr->derived_from<ContainerType>();
  }
  static inline void PrintName(std::ostringstream& os) {  // NOLINT(*)
    using ContainerType = typename T::ContainerType;
......
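With RTTI disabled by the new -fno-rtti flag, dynamic_cast is unavailable, so the checker now asks the node itself via derived_from, which walks the parent links declared by TVM_DECLARE_NODE_TYPE_INFO(Type, Parent). A minimal self-contained sketch of the idea (not TVM's actual implementation):

// Each type records its parent; a derived-from test walks the chain.
struct TypeInfo {
  const char* key;         // e.g. "ComputeOp"
  const TypeInfo* parent;  // nullptr at the hierarchy root
};

inline bool DerivedFrom(const TypeInfo* t, const TypeInfo* base) {
  for (; t != nullptr; t = t->parent) {
    if (t == base) return true;
  }
  return false;
}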
@@ -153,6 +153,13 @@ typedef void* TVMRetValueHandle;
typedef TVMArray* TVMArrayHandle;

/*!
+ * \brief Used for implementing C API functions.
+ *  Set the last error message before returning.
+ * \param msg The error message to be set.
+ */
+TVM_DLL void TVMAPISetLastError(const char* msg);
+
+/*!
 * \brief Return the string message of the last error.
 *  All functions in this file return 0 on success
 *  and -1 when an error occurred,
@@ -287,10 +294,10 @@ TVM_DLL int TVMCFuncSetReturn(TVMRetValueHandle ret,
 * \param num_args Number of arguments.
 * \param ret The return value handle.
 * \param resource_handle The additional resource handle from the front-end.
- *
+ * \return 0 on success, -1 on failure; set the error via TVMAPISetLastError.
 * \sa TVMCFuncSetReturn
 */
-typedef void (*TVMPackedCFunc)(
+typedef int (*TVMPackedCFunc)(
    TVMValue* args,
    int* type_codes,
    int num_args,
......
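A hedged sketch of the updated callback contract (the function name is hypothetical; the signature follows the typedef and the parameter docs above):

// On failure, record a message with TVMAPISetLastError and return -1;
// on success, return 0 (optionally setting a return value through ret).
int MyPackedCFunc(TVMValue* args, int* type_codes, int num_args,
                  TVMRetValueHandle ret, void* resource_handle) {
  if (num_args < 1) {
    TVMAPISetLastError("MyPackedCFunc: expected at least one argument");
    return -1;
  }
  return 0;
}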
@@ -331,7 +331,7 @@ class StageNode : public Node {
  }

  static constexpr const char* _type_key = "Stage";
-  TVM_DECLARE_NODE_TYPE_INFO(StageNode);
+  TVM_DECLARE_NODE_TYPE_INFO(StageNode, Node);
};

/*! \brief node container for schedule */
@@ -354,7 +354,7 @@ class ScheduleNode : public Node {
  }

  static constexpr const char* _type_key = "Schedule";
-  TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
+  TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode, Node);
};

/*! \brief node container for IterVar attr */
@@ -368,11 +368,14 @@ class IterVarAttrNode : public Node {
  }

  static constexpr const char* _type_key = "IterVarAttr";
-  TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode);
+  TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode, Node);
};

/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
+ public:
+  static constexpr const char* _type_key = "IterVarRelation";
+  TVM_DECLARE_BASE_NODE_INFO(IterVarRelationNode, Node);
};

/*!
@@ -402,7 +405,7 @@ class SplitNode : public IterVarRelationNode {
                              IterVar inner, Expr factor);

  static constexpr const char* _type_key = "Split";
-  TVM_DECLARE_NODE_TYPE_INFO(SplitNode);
+  TVM_DECLARE_NODE_TYPE_INFO(SplitNode, IterVarRelationNode);
};

/*!
@@ -427,7 +430,7 @@ class FuseNode : public IterVarRelationNode {
                              IterVar outer, IterVar inner, IterVar fused);

  static constexpr const char* _type_key = "Fuse";
-  TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
+  TVM_DECLARE_NODE_TYPE_INFO(FuseNode, IterVarRelationNode);
};

/*!
@@ -450,7 +453,7 @@ class RebaseNode : public IterVarRelationNode {
  static IterVarRelation make(IterVar parent, IterVar rebased);

  static constexpr const char* _type_key = "Rebase";
-  TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
+  TVM_DECLARE_NODE_TYPE_INFO(RebaseNode, IterVarRelationNode);
};
......
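The convention these hunks establish, sketched with a hypothetical relation node (MyRelNode is illustrative; IterVarRelationNode and the macros are from this diff): bases self-declare with TVM_DECLARE_BASE_NODE_INFO, and leaves name their parent in TVM_DECLARE_NODE_TYPE_INFO so derived_from can walk the hierarchy without RTTI.

class MyRelNode : public IterVarRelationNode {
 public:
  static constexpr const char* _type_key = "MyRel";
  // the second argument names the parent for RTTI-free hierarchy checks
  TVM_DECLARE_NODE_TYPE_INFO(MyRelNode, IterVarRelationNode);
};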
@@ -153,7 +153,7 @@ class TensorNode : public Node {
                     int value_index);

  static constexpr const char* _type_key = "Tensor";
-  TVM_DECLARE_NODE_TYPE_INFO(TensorNode);
+  TVM_DECLARE_NODE_TYPE_INFO(TensorNode, Node);
};

/*!
@@ -167,8 +167,6 @@ class OperationNode : public FunctionBaseNode {
  const std::string& func_name() const final {
    return name;
  }
-  /*! \return number of outputs of this op */
-  virtual int num_outputs() const = 0;
  /*! \return the list of iteration variable at root */
  virtual Array<IterVar> root_iter_vars() const = 0;
  /*! \return type of i-th output */
@@ -177,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
  virtual Array<Expr> output_shape(size_t i) const = 0;

  static constexpr const char* _type_key = "Operation";
+  TVM_DECLARE_BASE_NODE_INFO(OperationNode, Node);
};

// Implementations of inline functions
......
@@ -40,10 +40,11 @@ USE_CUDA = 1
# whether use OpenCL during compile
USE_OPENCL = 0

-# add the path to CUDA library to link and compile flag
-# if you have already add them to environment variable, leave it as NONE
-# USE_CUDA_PATH = /usr/local/cuda
-USE_CUDA_PATH = NONE
+# whether to build with LLVM support
+# This requires llvm-config to be in your PATH
+# Requires LLVM version >= 4.0
+USE_LLVM = 0

-# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
-USE_NVRTC = 0
+# add the path to the CUDA library for link and compile flags
+# if you have already added them to your environment variables, leave this commented out.
+# CUDA_PATH = /usr/local/cuda
@@ -56,6 +56,7 @@ def convert_to_tvm_func(pyfunc):
        check_call(_LIB.TVMCFuncSetReturn(ret, values[0], ctypes.c_int(tcodes[0])))
        _ = temp_args
        _ = rv
+        return 0

    handle = FunctionHandle()
    f = TVMPackedCFunc(cfun)
......
@@ -96,7 +96,7 @@ class TVMByteArray(ctypes.Structure):
TVMPackedCFunc = ctypes.CFUNCTYPE(
-    None,
+    ctypes.c_int,
    ctypes.POINTER(TVMValue),
    ctypes.POINTER(ctypes.c_int),
    ctypes.c_int,
......
@@ -6,3 +6,4 @@
- arithmetic Arithmetic expression and set simplification
- pass The optimization pass on the IR structure
- runtime Minimum runtime related codes.
+- codegen The code generator
@@ -37,6 +37,11 @@ TVM_REGISTER_API(_codegen_BuildStackVM)
        std::unordered_map<LoweredFunc, PackedFunc>());
  });

+TVM_REGISTER_API(_codegen_BuildLLVM)
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    *ret = BuildLLVM(args[0]);
+  });
+
TVM_REGISTER_API(_codegen_BuildNVRTC)
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = BuildNVRTC(args[0], args[1]);
......
@@ -20,7 +20,7 @@ enum SignType {
};

// internal node container of int set.
-class IntSetNode;
+struct IntSetNode;

/*!
 * \brief Integer set class, represents a set of integers in one dimension.
@@ -104,6 +104,8 @@ class IntSet : public NodeRef {
 * \brief Base class of all IntSet containers.
 */
struct IntSetNode : public Node {
+  static constexpr const char* _type_key = "IntSet";
+  TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};

using ExprIntSetMap = std::unordered_map<Expr, IntSet,
......
@@ -35,7 +35,7 @@ struct IntervalSet : public IntSetNode {
  }

  static constexpr const char* _type_key = "IntervalSet";
-  TVM_DECLARE_NODE_TYPE_INFO(IntervalSet);
+  TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode);
};

/*!
@@ -51,7 +51,7 @@ struct StrideSet : public IntSetNode {
  Array<Expr> strides;

  static constexpr const char* _type_key = "StrideSet";
-  TVM_DECLARE_NODE_TYPE_INFO(StrideSet);
+  TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode);
};

}  // namespace arith
......
@@ -272,9 +272,6 @@ inline void PushBinary(StackVM::OpCode op_int64,
  }
}

inline void PushCast(Type dst,
                     Type src,
                     CodeGenStackVM* p) {
@@ -496,7 +493,5 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<Call>([](const Call *op, CodeGenStackVM* p) {
    p->Push_(op);
  });

}  // namespace codegen
}  // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file codegen_llvm.h
* \brief Common base class for generating LLVM IR
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/codegen.h>
#include <memory>
#include <vector>
#include <string>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
using namespace ir;
/*!
* \brief A base class for LLVM code generation.
*/
class CodeGenLLVM : public IRVisitor {
public:
/*!
* \brief Initialize the code generator with given context
* \param module_name The name of the module.
* \param ctx The context.
*/
void Init(const std::string& module_name, llvm::LLVMContext* ctx);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
*/
void AddFunction(const LoweredFunc& f);
/*!
* \brief Finish current pass of codegen, get the module.
* \return the created module.
*/
std::unique_ptr<llvm::Module> Finish();
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
* \return created value.
*/
llvm::Value* MakeValue(const Expr& e) {
value_ = nullptr;
this->Visit(e);
CHECK(value_ != nullptr);
return value_;
}
// Shorthand for getting a constant int32
llvm::Constant* ConstInt32(unsigned value) const {
return llvm::ConstantInt::get(t_int32_, value);
}
// override codegen
void Visit_(const Variable* op) final;
void Visit_(const Cast* op) final;
void Visit_(const IntImm* op) final;
void Visit_(const UIntImm* op) final;
void Visit_(const FloatImm* op) final;
void Visit_(const StringImm* op) final;
void Visit_(const Add* op) final;
void Visit_(const Sub* op) final;
void Visit_(const Mul* op) final;
void Visit_(const Div* op) final;
void Visit_(const Mod* op) final;
void Visit_(const Min* op) final;
void Visit_(const Max* op) final;
void Visit_(const LT* op) final;
void Visit_(const LE* op) final;
void Visit_(const GT* op) final;
void Visit_(const GE* op) final;
void Visit_(const EQ* op) final;
void Visit_(const NE* op) final;
void Visit_(const And* op) final;
void Visit_(const Or* op) final;
void Visit_(const Not* op) final;
void Visit_(const Select* op) final;
void Visit_(const Let* op) final;
void Visit_(const Load* op) final;
void Visit_(const Call* op) final;
void Visit_(const Ramp* op) final;
void Visit_(const Broadcast* op) final;
// stmt
void Visit_(const Store* op) final;
void Visit_(const For* op) final;
void Visit_(const IfThenElse* op) final;
void Visit_(const Allocate* op) final;
void Visit_(const AttrStmt* op) override;
void Visit_(const AssertStmt* op) final;
void Visit_(const LetStmt* op) final;
// create intrinsic given call
virtual llvm::Value* CreateIntrinsic(const Call* op);
// create extern function call
virtual llvm::Value* CreateCallExtern(const Call* op);
// create call into tvm packed function.
virtual llvm::Value* CreateCallPacked(const Call* op);
protected:
/*!
 * \brief Get the LLVM type corresponding to a TVM type.
 * \param t The original type.
 * \return LLVM type of t
 */
llvm::Type* LLVMType(const Type& t) const;
// perform a scalarized call: invoke f once per lane
llvm::Value* CreateScalarizedCall(
const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
// apply optimization on the module.
virtual void Optimize();
// The IRBuilder.
using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
// The current function
llvm::Function* function_;
// Internal builder
std::unique_ptr<IRBuilder> builder_;
// The module to be returned.
std::unique_ptr<llvm::Module> module_;
// Internal metadata builder
std::unique_ptr<llvm::MDBuilder> md_builder_;
// llvm context
llvm::LLVMContext* ctx_{nullptr};
// helpful data types
llvm::Type* t_void_{nullptr};
llvm::Type* t_void_p_{nullptr};
llvm::Type* t_int_{nullptr};
llvm::Type* t_char_{nullptr};
llvm::Type* t_int8_{nullptr};
llvm::Type* t_int16_{nullptr};
llvm::Type* t_int32_{nullptr};
llvm::Type* t_int64_{nullptr};
llvm::Type* t_float64_{nullptr};
// commonly used metadata nodes
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
// TVM related data types
llvm::Type* t_tvm_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
llvm::StructType* t_tvm_context_{nullptr};
llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_func_get_global_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
// The current basic block being generated into
llvm::BasicBlock* block_{nullptr};
// Last value returned by a codegen call.
llvm::Value* value_{nullptr};
private:
// look up the value bound to a variable
llvm::Value* GetVarValue(const Variable* v) const;
// comparison ops
llvm::Value* CreateLT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateLE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGT(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateGE(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateAdd(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateSub(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(Type t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* GetConstString(const std::string& str);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str, bool global);
// Check whether a call to a packed function succeeded;
// if not, finalize the function directly and pass on the return code.
// Returns the block that follows the check.
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index);
// The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
};
} // namespace codegen
} // namespace tvm
#endif  // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
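Putting the interface together, a usage sketch that mirrors BuildLLVM in llvm_exec_engine.cc below (CompileOne is a hypothetical helper; it assumes InitializeLLVM() from llvm_common.h has already been called):

// Lower a single LoweredFunc to an llvm::Module.
std::unique_ptr<llvm::Module> CompileOne(const tvm::LoweredFunc& f) {
  static llvm::LLVMContext ctx;  // the context must outlive the module
  tvm::codegen::CodeGenLLVM cg;
  cg.Init(f->name, &ctx);        // name the module after the function
  cg.AddFunction(f);             // emit LLVM IR for f
  return cg.Finish();            // runs Optimize(), hands back the module
}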
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_common.cc
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/base.h>
#include <mutex>
#include "./llvm_common.h"
namespace tvm {
namespace codegen {
struct LLVMEnv {
std::mutex mu;
bool native_initialized{false};
static LLVMEnv* Global() {
static LLVMEnv inst;
return &inst;
}
};
void InitializeLLVM() {
LLVMEnv* e = LLVMEnv::Global();
if (!e->native_initialized) {
// named lock_guard: an unnamed temporary would unlock immediately
std::lock_guard<std::mutex> lock(e->mu);
if (!e->native_initialized) {
e->native_initialized = true;
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
}
}
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
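For comparison, the same initialize-once behavior could be written with std::call_once; a sketch of that alternative, not what the commit uses:

#include <mutex>

// std::call_once runs the lambda exactly once across all threads,
// subsuming the hand-rolled double-checked lock above.
void InitializeLLVMOnce() {
  static std::once_flag flag;
  std::call_once(flag, []() {
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();
    llvm::InitializeNativeTargetAsmParser();
  });
}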
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_common.h
* \brief Common utilities for llvm initialization.
*/
#ifndef TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#define TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#ifdef TVM_LLVM_VERSION
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
#include <llvm/IR/Constants.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/Function.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Type.h>
#include <llvm/IR/Intrinsics.h>
#include <llvm/IR/MDBuilder.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/TargetSelect.h>
extern "C" {
// Function signature for an LLVM-generated packed function.
typedef int (*LLVMPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace codegen {
/*!
 * \brief Initialize LLVM for this process.
 *  It is safe to call this multiple times.
*/
void InitializeLLVM();
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_LLVM_COMMON_H_
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_exec_engine.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
#include "./llvm_common.h"
#include "./codegen_llvm.h"
namespace tvm {
namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
#ifdef TVM_LLVM_VERSION
// Environment to keep jit resources alive.
struct LLVMJITEnv {
std::shared_ptr<llvm::LLVMContext> ctx;
llvm::ExecutionEngine* ee{nullptr};
// constructor
LLVMJITEnv(std::shared_ptr<llvm::LLVMContext> ctx,
llvm::ExecutionEngine* ee)
: ctx(ctx), ee(ee) {
}
// destructor
~LLVMJITEnv() {
if (ee != nullptr) {
ee->runStaticConstructorsDestructors(true);
delete ee;
}
}
};
PackedFunc JITCompile(std::unique_ptr<llvm::Module> module,
std::shared_ptr<llvm::LLVMContext> ctx,
const std::string& func_name) {
llvm::EngineBuilder builder(std::move(module));
builder.setEngineKind(llvm::EngineKind::JIT);
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
std::shared_ptr<LLVMJITEnv> env = std::make_shared<LLVMJITEnv>(
ctx, builder.create());
CHECK(env->ee != nullptr);
auto* faddr = reinterpret_cast<LLVMPackedCFunc>(
env->ee->getFunctionAddress(func_name));
env->ee->runStaticConstructorsDestructors(false);
return PackedFunc([env, faddr](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK(ret == 0) << TVMGetLastError();
});
}
PackedFunc BuildLLVM(LoweredFunc func) {
InitializeLLVM();
// use one context per function.
std::shared_ptr<llvm::LLVMContext> ctx =
std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;
cg.Init(func->name, ctx.get());
cg.AddFunction(func);
std::unique_ptr<llvm::Module> m = cg.Finish();
return JITCompile(std::move(m), ctx, func->name);
}
#else
PackedFunc BuildLLVM(LoweredFunc func) {
LOG(FATAL) << "LLVM is not enabled";
return PackedFunc();
}
#endif // TVM_LLVM_VERSION
} // namespace codegen
} // namespace tvm
@@ -21,8 +21,6 @@
 */
#define API_END_HANDLE_ERROR(Finalize) } catch(std::runtime_error &_except_) { Finalize; return TVMAPIHandleException(_except_); } return 0; // NOLINT(*)

-void TVMAPISetLastError(const char* msg);

/*!
 * \brief handle exceptions thrown out
 * \param e the exception
......
@@ -274,8 +274,9 @@ Stmt MakeLoop(const Stage& s,
    bound_state[iv] = false;
  }
  PassUpBoundCheck(s, dom_map, &bound_state);
-  auto nest = MakeLoopNest(s, dom_map, 0, false,
-                           bound_state, {}, &value_map);
+  auto nest = MakeLoopNest(
+      s, dom_map, 0, false,
+      bound_state, {{}}, &value_map);
  provide = Substitute(provide, value_map);
  if (init.defined()) {
......
@@ -2,7 +2,6 @@ import tvm
import numpy as np

def test_add_pipeline():
-    """Not yet working, mock design"""
    n = tvm.Var('n')
    A = tvm.placeholder((n,), name='A')
    B = tvm.placeholder((n,), name='B')
......
@@ -6,6 +6,16 @@ def tvm_call_global(*args):
    return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0)

+def run_jit(fapi, check):
+    for target in ["stackvm"]:
+        if target == "llvm":
+            f = tvm.codegen.BuildLLVM(fapi)
+        else:
+            f = tvm.codegen.BuildStackVM(fapi)
+        check(f)
+
def test_stack_vm_basic():
    a = tvm.nd.array(np.zeros(10, dtype='float32'))

    @tvm.register_func
@@ -17,8 +27,7 @@ def test_stack_vm_basic():
    Ab = tvm.Buffer((n, ), tvm.float32)
    stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
    fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 1)
-    f = tvm.codegen.BuildStackVM(fapi)
-    f(a)
+    run_jit(fapi, lambda f: f(a))

@tvm.register_func
@@ -42,8 +51,10 @@ def test_stack_vm_loop():
    fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 1)
    f = tvm.codegen.BuildStackVM(fapi)
    a = tvm.nd.array(np.zeros(10, dtype=dtype))
-    f(a)
-    np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
+    def check(f):
+        f(a)
+        np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
+    run_jit(fapi, check)
def test_stack_vm_cond():
@@ -61,15 +72,46 @@ def test_stack_vm_cond():
                tvm.make.Store(Ab.data,
                               tvm.make.Load(dtype, Ab.data, i) + 2, i + 1)))
    fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 1)
-    f = tvm.codegen.BuildStackVM(fapi)
-    a = tvm.nd.array(np.zeros(10, dtype=dtype))
-    f(a)
-    y = np.arange(a.shape[0]) * 2
-    y[5:] -= 1
-    np.testing.assert_equal(a.asnumpy(), y)
+    def check(f):
+        a = tvm.nd.array(np.zeros(10, dtype=dtype))
+        f(a)
+        y = np.arange(a.shape[0]) * 2
+        y[5:] -= 1
+        np.testing.assert_equal(a.asnumpy(), y)
+    run_jit(fapi, check)
+def test_llvm_add_pipeline():
+    n = tvm.Var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.Schedule(C.op)
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.Buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.Buffer(B.shape, B.dtype, name='B')
+    Cb = tvm.Buffer(C.shape, C.dtype, name='C')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cb})
+    stmt = tvm.ir_pass.Simplify(stmt)
+    fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3)
+
+    def check_llvm():
+        # build and invoke the kernel.
+        f = tvm.codegen.BuildLLVM(fapi)
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = 1027
+        a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
+        f(a, b, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy() + b.asnumpy())
+    # check_llvm()
if __name__ == "__main__":
    test_stack_vm_cond()
-    test_stack_vm_loop()
    test_stack_vm_basic()
+    test_stack_vm_loop()
+    test_llvm_add_pipeline()