Unverified Commit 841725cc by Tianqi Chen Committed by GitHub

[TIR][TARGET] Refactor Target codegen to use IRModule and PrimFunc. (#5107)

As part of the unified IR refactor.
This PR refactors the target codegen to use IRModule containing tir::PrimFuncs.

In order to break the refactor into several steps without breaking the codebase,
we built a conversion pass to convert Array<LoweredFunc> into IRModule.

The follow-up refactors will gradually move the passes covered by IRModule up
until we cover all the passes. Then we can remove the additional redundant
concepts such as LoweredFunc.
parent 86079479
......@@ -32,6 +32,7 @@
#include <string>
#include <algorithm>
#include <limits>
#include <type_traits>
namespace tvm {
......@@ -308,6 +309,17 @@ class Integer : public IntImm {
*/
Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
/*!
 * \brief Construct an Integer from an enum value.
 *
 * Only participates in overload resolution when ENum is an enum type,
 * and only compiles when the enum's underlying type is exactly int.
 *
 * \tparam ENum The enum type.
 * \param value The enum value, stored as its int representation.
 */
template<typename ENum,
         typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
explicit Integer(ENum value) : Integer(static_cast<int>(value)) {
  using UnderlyingType = typename std::underlying_type<ENum>::type;
  static_assert(std::is_same<UnderlyingType, int>::value,
                "declare enum to be enum int to use visitor");
}
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
......
......@@ -213,6 +213,27 @@ constexpr const char* kCallingConv = "calling_conv";
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";
/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given symbol.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";
} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
......@@ -114,7 +114,8 @@ class PrimTypeNode : public TypeNode {
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
};
/*!
/*
* \brief Managed reference to PrimTypeNode.
* \sa PrimTypeNode
*/
......@@ -124,11 +125,53 @@ class PrimType : public Type {
* \brief Constructor
* \param dtype The corresponding dtype.
*/
TVM_DLL PrimType(runtime::DataType dtype);
TVM_DLL explicit PrimType(runtime::DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
};
/*!
* \brief Low-level raw pointer type.
*
* PointerType represents type hints in the TIR to be
* passed to the final code generator.
*
* PointerType should not occur in the high-level analysis.
*
* \sa PointerType
*/
class PointerTypeNode : public TypeNode {
public:
/*!
* \brief The type of the element which the pointer points to.
*/
Type element_type;
/*!
* \brief Expose the fields of the node to an attribute visitor.
* \param v The visitor; it is given element_type under the key "element_type".
*/
void VisitAttrs(AttrVisitor* v) {
v->Visit("element_type", &element_type);
}
/*! \brief Type key string of this node type. */
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
/*!
* \brief Managed reference to PointerTypeNode.
* \sa PointerTypeNode
*/
class PointerType : public Type {
public:
/*!
* \brief Constructor
* \param element_type The type of the element which the pointer points to.
*/
TVM_DLL explicit PointerType(Type element_type);
TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
/*! \brief Possible kinds of TypeVars. */
enum TypeKind : int {
kType = 0,
......@@ -284,6 +327,15 @@ inline Type VoidType() {
}
/*!
* \brief Check whether the tyep represents void.
* \return The check result.
*/
inline bool IsVoidType(const Type& type) {
auto* n = type.as<TupleTypeNode>();
return n && n->fields.size() == 0;
}
/*!
* \brief Potential Constraints in a function.
* \sa TypeConstraint
*/
......
......@@ -55,22 +55,27 @@ namespace tir {
*/
class VarNode : public PrimExprNode {
public:
/*! \brief constructor */
VarNode() {}
VarNode(DataType dtype, std::string name_hint);
/*!
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;
/*!
* \brief type annotation of the variable.
*
* It is an optional field that provides a refined type of the variable than dtype.
*
* \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type.
*/
Type type_annotation;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
v->Visit("type_annotation", &type_annotation);
}
static constexpr const char* _type_key = "Variable";
static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};
......@@ -78,20 +83,25 @@ class VarNode : public PrimExprNode {
class Var : public PrimExpr {
public:
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
/*! \brief constructor
/*!
* \brief Constructor
* \param name_hint variable name
* \param t data type
* \param dtype data type
*/
TVM_DLL explicit Var(std::string name_hint = "v",
DataType t = DataType::Int(32));
DataType dtype = DataType::Int(32));
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
*/
TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->dtype);
}
TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
......@@ -116,15 +126,7 @@ class Var : public PrimExpr {
*/
class SizeVarNode : public VarNode {
public:
/*! \brief constructor */
SizeVarNode() {}
/*! \brief constructor
* \param dtype data type
* \param name_hint variable name
*/
SizeVarNode(DataType dtype, std::string name_hint);
static constexpr const char* _type_key = "SizeVar";
static constexpr const char* _type_key = "tir.SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};
......@@ -132,12 +134,13 @@ class SizeVarNode : public VarNode {
class SizeVar : public Var {
public:
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
/*! \brief constructor
/*!
* \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s",
DataType t = DataType::Int(32));
DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
......
......@@ -171,6 +171,16 @@ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
* Type: Integer
*/
constexpr const char* kNoAlias = "tir.noalias";
/*!
* \brief Mark the function as the entry function of
* the final generated runtime module.
*
* Type: Integer
*
* \note There can only be one entry function per module.
*/
constexpr const char* kIsEntryFunc = "tir.is_entry_func";
} // namespace attr
} // namespace tir
} // namespace tvm
......
......@@ -30,6 +30,7 @@
#include <tvm/te/schedule.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
#include <tvm/tir/lowered_func.h>
#include <unordered_map>
......@@ -515,6 +516,19 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemented in storage_rewrite.cc
* \param f The function to be transformed
* \return Transformed function.
*/
PrimFunc PointerValueTypeRewrite(PrimFunc f);
/*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
......
......@@ -52,11 +52,24 @@ namespace tvm {
* This function could return a more refined type than
* the runtime type provided by expr->dtype
*
* \param expr The input parameter.
* \return The result type.
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/
TVM_DLL Type GetType(const PrimExpr& expr);
/*!
* \brief Get the implied DataType for storing values with type during runtime.
*
* \param type The input type.
* \return The result runtime::DataType.
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/
TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
* \return the maximum possible value in this format.
......
......@@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType
from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
......
......@@ -72,7 +72,15 @@ def create_updater_06_to_07():
return item
return _convert
# Build a converter for a serialized TIR variable node.
# The returned converter rewrites the node's "type_key" to new_name and
# adds a "type_annotation" attribute, then returns the same item.
def _update_tir_var(new_name):
# The second positional argument of the converter is not used here.
def _convert(item, _):
item["type_key"] = new_name
# NOTE(review): "0" is presumably the index of the null node in the
# serialized graph, i.e. "no type annotation" -- confirm against the
# JSON format.
item["attrs"]["type_annotation"] = "0"
return item
return _convert
node_map = {
# Base IR
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
......@@ -91,6 +99,9 @@ def create_updater_06_to_07():
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
}
return create_updater(node_map, "0.6", "0.7")
......
......@@ -46,6 +46,7 @@ class TypeKind(IntEnum):
TypeData = 6
@tvm._ffi.register_object("PrimType")
class PrimType(Type):
"""Primitive data type in the low level IR
......@@ -59,6 +60,20 @@ class PrimType(Type):
_ffi_api.PrimType, dtype)
# Python wrapper registered under the object type key "PointerType".
# Construction is delegated to the _ffi_api.PointerType constructor function.
@tvm._ffi.register_object("PointerType")
class PointerType(Type):
"""PointerType used in the low-level TIR.
Parameters
----------
element_type : tvm.ir.Type
The type of pointer's element.
"""
def __init__(self, element_type):
self.__init_handle_by_constructor__(
_ffi_api.PointerType, element_type)
@tvm._ffi.register_object("TypeVar")
class TypeVar(Type):
"""Type parameter in functions.
......
......@@ -288,7 +288,7 @@ class CmpExpr(PrimExprWithOp):
class LogicalExpr(PrimExprWithOp):
pass
@tvm._ffi.register_object("Variable")
@tvm._ffi.register_object("tir.Var")
class Var(PrimExprWithOp):
"""Symbolic variable.
......@@ -297,7 +297,7 @@ class Var(PrimExprWithOp):
name : str
The name
dtype : str
dtype : Union[str, tvm.ir.Type]
The data type
"""
def __init__(self, name, dtype):
......@@ -305,7 +305,7 @@ class Var(PrimExprWithOp):
_ffi_api.Var, name, dtype)
@tvm._ffi.register_object
@tvm._ffi.register_object("tir.SizeVar")
class SizeVar(Var):
"""Symbolic variable to represent a tensor index size
which is greater or equal to zero.
......
......@@ -68,7 +68,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
/* TODO: Thread extent unitest not generated.*/
void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
if (op->attr_key == tir::attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis);
const VarNode* var = thread_axis->var.get();
......
......@@ -92,8 +92,8 @@ VisitStmt_(const IfThenElseNode* op) {
Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var,
......
......@@ -40,7 +40,7 @@ using runtime::PackedFunc;
using tir::LoweredFunc;
bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm");
const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
return pf != nullptr;
}
......
......@@ -45,6 +45,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
// Create a PointerTypeNode that takes ownership of element_type and make
// this reference point at it.
PointerType::PointerType(Type element_type) {
  auto node = make_object<PointerTypeNode>();
  node->element_type = std::move(element_type);
  data_ = std::move(node);
}
// Register PointerTypeNode as a node type.
TVM_REGISTER_NODE_TYPE(PointerTypeNode);
// Global function "ir.PointerType": builds a PointerType from an element type.
TVM_REGISTER_GLOBAL("ir.PointerType")
.set_body_typed([](Type element_type) {
return PointerType(element_type);
});
// Printer for PointerTypeNode: prints the element type followed by '*'.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PointerTypeNode*>(ref.get());
p->Print(node->element_type);
p->stream << '*';
});
TypeVar::TypeVar(std::string name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name);
......
......@@ -139,7 +139,7 @@ bool RuntimeEnabled(const std::string& target) {
} else if (target == "vulkan") {
f_name = "device_api.vulkan";
} else if (target == "stackvm") {
f_name = "codegen.build_stackvm";
f_name = "target.build.stackvm";
} else if (target == "rpc") {
f_name = "device_api.rpc";
} else if (target == "micro_dev") {
......
......@@ -26,6 +26,9 @@
#include <tvm/target/codegen.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/ir/module.h>
#include <tvm/tir/function.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/lowered_func.h>
......@@ -51,6 +54,31 @@ ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) {
}
return fmap;
}
/*!
 * \brief Collect the runtime function information of every function in a module.
 *
 * Every function in the module must be a tir::PrimFunc that carries the
 * global_symbol attribute; the symbol is used as the key of the result.
 *
 * \param mod The module whose functions are inspected.
 * \return Map from global symbol to the FunctionInfo (argument dtypes and
 *         device thread axis tags) of the corresponding function.
 */
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule& mod) {
  std::unordered_map<std::string, runtime::FunctionInfo> fmap;

  for (auto kv : mod->functions) {
    CHECK(kv.second->IsInstance<tir::PrimFuncNode>())
        << "Can only lower IR Module with PrimFuncs";
    auto f = Downcast<tir::PrimFunc>(kv.second);

    runtime::FunctionInfo info;
    for (size_t i = 0; i < f->params.size(); ++i) {
      info.arg_types.push_back(f->params[i].dtype());
    }

    // The thread axis attribute is optional; only record tags when present.
    auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis);
    if (thread_axis.defined()) {
      for (size_t i = 0; i < thread_axis.size(); ++i) {
        info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
      }
    }

    // The symbol is the map key, so fail with a clear message instead of
    // converting an undefined attribute to std::string.
    auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
    CHECK(global_symbol.defined())
        << "ExtractFuncInfo: Expect PrimFunc to have the global_symbol attribute";
    fmap[static_cast<std::string>(global_symbol)] = info;
  }
  return fmap;
}
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_BUILD_COMMON_H_
......@@ -23,7 +23,12 @@
*/
#include <tvm/target/codegen.h>
#include <tvm/target/target.h>
#include <tvm/ir/module.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/function.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
......@@ -37,6 +42,63 @@
namespace tvm {
namespace codegen {
// The new build function.
// adapt the old function to the new one
//
// Looks up the packed function registered as "target.build.<target_name>"
// and invokes it with the module and the string form of the target.
// Fails with a CHECK error when no such function is registered.
runtime::Module BuildForIRModule(const IRModule& module,
                                 const Target& target) {
  const std::string build_f_name = "target.build." + target->target_name;
  // the build function.
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  // Report the exact registry key that was looked up; streaming the whole
  // target after the prefix would name a key that was never queried.
  CHECK(bf != nullptr)
      << build_f_name << " is not enabled";
  return (*bf)(module, target->str());
}
// convert legacy LoweredFunc to PrimFunc.
//
// Arguments that have an entry in handle_data_type are replaced by new
// variables annotated with a PointerType, and the body is rewritten to use
// them. The name, thread axis, function type and restricted flag of the
// LoweredFunc are carried over as function attributes.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
  // remap args to attach type annotations.
  Array<tir::Var> args;
  Map<tir::Var, PrimExpr> remap_vars;

  for (auto var : from->args) {
    auto it = from->handle_data_type.find(var);
    if (it != from->handle_data_type.end()) {
      // The pointee type is the dtype of the placeholder stored in
      // handle_data_type. var->dtype is the handle type itself and would
      // yield a pointer-to-handle annotation.
      tir::Var new_var(var->name_hint,
                       PointerType(PrimType((*it).second.dtype())));
      args.push_back(new_var);
      remap_vars.Set(var, new_var);
    } else {
      args.push_back(var);
    }
  }
  tir::PrimFunc func(args, Substitute(from->body, remap_vars));

  func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name));
  func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis);

  if (from->func_type == tir::LoweredFuncType::kDeviceFunc) {
    func = WithAttr(std::move(func),
                    attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch));
  }
  if (from->is_restricted) {
    func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1));
  }
  return func;
}
// Convert an array of legacy LoweredFuncs into an IRModule of PrimFuncs.
// Each function is keyed by a GlobalVar carrying its name; the first
// function in the array is marked as the entry function.
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
  Map<GlobalVar, BaseFunc> functions;
  bool is_first = true;
  for (tir::LoweredFunc lowered : funcs) {
    tir::PrimFunc prim = ToPrimFunc(lowered);
    if (is_first) {
      prim = WithAttr(std::move(prim), tir::attr::kIsEntryFunc, Integer(1));
      is_first = false;
    }
    functions.Set(GlobalVar(lowered->name), prim);
  }
  return IRModule(functions);
}
runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target) {
std::string mode = target;
......@@ -51,15 +113,10 @@ runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
transformed_funcs.push_back(func);
}
}
std::string build_f_name = "codegen.build_" + mode;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "Target " << target << " is not enabled";
runtime::Module m = transformed_funcs.empty() ?
(*bf)(funcs, target) :
(*bf)(transformed_funcs, target);
return m;
return BuildForIRModule(
transformed_funcs.size() != 0 ? ToIRModule(transformed_funcs) : ToIRModule(funcs),
Target::Create(target));
}
/*! \brief Helper class to serialize module */
......
......@@ -59,7 +59,7 @@ static inline int DetectROCMmaxThreadsPerBlock() {
// AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM {
public:
void AddFunction(const LoweredFunc& f) final {
void AddFunction(const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true);
function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
......@@ -91,7 +91,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
LLVMType(op->dtype), ConstInt32(constant_size));
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
......@@ -106,7 +106,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size);
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
......@@ -120,7 +121,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
}
buf = builder_->CreatePointerCast(
buf, LLVMType(op->dtype)->getPointerTo(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
......@@ -170,7 +171,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
// Additional optimization hook to tweak the builder.
}
unsigned GetGlobalAddressSpace() {
unsigned GetGlobalAddressSpace() const final {
return 1;
}
......@@ -205,7 +206,7 @@ inline int DetectROCMComputeVersion(const std::string& target) {
return 900;
}
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
#if TVM_LLVM_VERSION < 90
LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
// Lower versions will crash when loading the bitcode, see
......@@ -222,8 +223,13 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false);
for (LoweredFunc f : funcs) {
cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false);
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f);
}
......@@ -306,10 +312,10 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
std::string hsaco = (*f)(arr);
std::string ll(data_ll.begin(), data_ll.end());
return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly);
return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly);
}
TVM_REGISTER_GLOBAL("codegen.build_rocm")
TVM_REGISTER_GLOBAL("target.build.rocm")
.set_body_typed(BuildAMDGPU);
} // namespace codegen
......
......@@ -122,11 +122,15 @@ void CodeGenCPU::Init(const std::string& module_name,
this->InitGlobalContext(dynamic_lookup);
}
void CodeGenCPU::AddFunction(const LoweredFunc& f) {
void CodeGenCPU::AddFunction(const PrimFunc& f) {
CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back(
std::make_pair(f->name, builder_->CreatePointerCast(function_, t_void_p_)));
std::make_pair(global_symbol.operator std::string(),
builder_->CreatePointerCast(function_, t_void_p_)));
}
AddDebugInformation(function_);
}
......@@ -328,7 +332,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
arg_types.push_back(v->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
LLVMType(op->dtype), arg_types, false);
GetLLVMType(GetRef<PrimExpr>(op)), arg_types, false);
// Check if it is available in global function table as injected function.
auto it = gv_func_map_.find(op->name);
if (it != gv_func_map_.end()) {
......@@ -693,8 +697,8 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
ret_value, *ret_tcode}));
DataType r_api_type = tir::APIType(r_type);
*rvalue = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ret_value,
LLVMType(r_api_type)->getPointerTo()),
builder_->CreatePointerCast(
ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()),
8);
*rvalue = CreateCast(r_api_type, r_type, *rvalue);
return end_block;
......@@ -873,7 +877,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
} else if (op->attr_key == tir::attr::compute_scope) {
this->CreateComputeScope(op);
} else if (attr::IsPragmaKey(op->attr_key)) {
} else if (tir::attr::IsPragmaKey(op->attr_key)) {
if (op->attr_key == "pragma_parallel_stride_pattern") {
CHECK(parallel_env_.penv != nullptr)
<< "Pragma parallel_stride_pattern only valid in parallel launch";
......
......@@ -42,7 +42,7 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::LLVMContext* ctx,
bool system_lib,
bool dynamic_lookup) override;
void AddFunction(const LoweredFunc& f) override;
void AddFunction(const PrimFunc& f) override;
void AddMainFunction(const std::string& entry_func_name) override;
std::unique_ptr<llvm::Module> Finish() override;
void VisitStmt_(const AssertStmtNode* op) override;
......
......@@ -24,6 +24,7 @@
// Part of the code are adapted from Halide's CodeGen_LLVM
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/tir/op.h>
#include <algorithm>
......@@ -94,7 +95,7 @@ void CodeGenLLVM::InitTarget(llvm::TargetMachine* tm) {
}
}
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
void CodeGenLLVM::AddFunction(const PrimFunc& f) {
this->AddFunctionInternal(f, false);
}
......@@ -107,41 +108,43 @@ void CodeGenLLVM::InitFuncState() {
}
void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
this->InitFuncState();
std::vector<llvm::Type*> arg_types;
is_restricted_ = f->is_restricted;
for (Var arg : f->args) {
DataType t = arg.dtype();
if (t.is_handle()) {
auto it = f->handle_data_type.find(arg);
if (it != f->handle_data_type.end()) {
arg_types.push_back(LLVMType((*it).second.dtype())
->getPointerTo(GetGlobalAddressSpace()));
} else {
arg_types.push_back(t_int8_->getPointerTo(GetGlobalAddressSpace()));
}
if (!is_restricted_) {
alias_var_set_.insert(arg.get());
}
} else {
arg_types.push_back(LLVMType(arg.dtype()));
CHECK_EQ(f->buffer_map.size(), 0U)
<< "Cannot codegen function with buffer_map, please lower them first";
std::vector<llvm::Type*> param_types;
is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias);
for (Var param : f->params) {
param_types.push_back(GetLLVMType(param));
if (!is_restricted_ && param.dtype().is_handle()) {
alias_var_set_.insert(param.get());
}
}
// TODO(tvm-team):
// Update the function type to respect the ret_type field of f.
// Once we allow more flexibility in the PrimFunc.
llvm::FunctionType* ftype = llvm::FunctionType::get(
ret_void ? t_void_ : t_int_, arg_types, false);
CHECK(module_->getFunction(f->name) == nullptr)
<< "Function " << f->name << " already exist in module";
ret_void ? t_void_ : t_int_, param_types, false);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
CHECK(module_->getFunction(static_cast<std::string>(global_symbol)) == nullptr)
<< "Function " << global_symbol << " already exist in module";
function_ = llvm::Function::Create(
ftype, llvm::Function::ExternalLinkage,
f->name, module_.get());
global_symbol.operator std::string(), module_.get());
function_->setCallingConv(llvm::CallingConv::C);
function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
// set var map and align information
auto arg_it = function_->arg_begin();
for (size_t i = 0; i < f->args.size(); ++i, ++arg_it) {
for (size_t i = 0; i < f->params.size(); ++i, ++arg_it) {
llvm::Argument* v = &(*arg_it);
const Var& var = f->args[i];
const Var& var = f->params[i];
var_map_[var.get()] = v;
if (is_restricted_) {
if (var.dtype().is_handle() && !alias_var_set_.count(var.get())) {
......@@ -157,6 +160,7 @@ void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) {
llvm::BasicBlock* entry = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(entry);
this->VisitStmt(f->body);
if (ret_void) {
builder_->CreateRetVoid();
} else {
......@@ -295,33 +299,51 @@ int CodeGenLLVM::NativeVectorBits(const runtime::StorageScope& storage_scope) co
return native_vector_bits_;
}
unsigned CodeGenLLVM::GetGlobalAddressSpace() {
unsigned CodeGenLLVM::GetGlobalAddressSpace() const {
return 0;
}
llvm::Type* CodeGenLLVM::LLVMType(const DataType& t) const {
if (t.is_handle()) {
CHECK_EQ(t.lanes(), 1);
llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
if (dtype.is_handle()) {
CHECK_EQ(dtype.lanes(), 1);
return t_void_p_;
}
llvm::Type* etype = nullptr;
if (t.is_int() || t.is_uint()) {
etype = llvm::Type::getIntNTy(*ctx_, t.bits());
} else if (t.is_float()) {
switch (t.bits()) {
if (dtype.is_int() || dtype.is_uint()) {
etype = llvm::Type::getIntNTy(*ctx_, dtype.bits());
} else if (dtype.is_float()) {
switch (dtype.bits()) {
case 16: etype = llvm::Type::getHalfTy(*ctx_); break;
case 32: etype = llvm::Type::getFloatTy(*ctx_); break;
case 64: etype = llvm::Type::getDoubleTy(*ctx_); break;
default: LOG(FATAL) << "do not support " << t;
default: LOG(FATAL) << "do not support " << dtype;
}
}
if (t.lanes() != 1) {
return llvm::VectorType::get(etype, t.lanes());
if (dtype.lanes() != 1) {
return llvm::VectorType::get(etype, dtype.lanes());
} else {
return etype;
}
}
// Map a Type to its LLVM type: PrimType goes through DTypeToLLVMType,
// PointerType becomes a pointer to its element type in the global address
// space, and the void type becomes LLVM void. Any other type is fatal.
llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const {
  if (const auto* prim = type.as<PrimTypeNode>()) {
    return DTypeToLLVMType(prim->dtype);
  }
  if (const auto* pointer = type.as<PointerTypeNode>()) {
    // TODO(tvm-team) consider put storage scope into the pointer type.
    llvm::Type* element = GetLLVMType(pointer->element_type);
    return element->getPointerTo(GetGlobalAddressSpace());
  }
  if (!IsVoidType(type)) {
    LOG(FATAL) << "Type " << type << " does not have a corresponding LLVM Type";
  }
  return t_void_;
}
// Map an expression to an LLVM type via the Type returned by GetType(expr).
llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const {
  Type expr_type = GetType(expr);
  return GetLLVMType(expr_type);
}
// Add tbaa alias information for load
//
// use a binary tree typed system to declare information
......@@ -471,7 +493,8 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
}
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* mask = llvm::UndefValue::get(LLVMType(DataType::Int(32, target_lanes)));
llvm::Value* mask = llvm::UndefValue::get(
DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
......@@ -552,16 +575,16 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin,
// cast operator
llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) {
llvm::Type * target = LLVMType(to);
llvm::Type * target = DTypeToLLVMType(to);
if (value->getType() == target) return value;
if (to.is_handle()) {
return builder_->CreateBitCast(value, target);
} else if (to.is_uint() && to.bits() == 1) {
if (from.is_float()) {
llvm::Constant* zero = llvm::ConstantFP::get(LLVMType(from), 0.);
llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
return builder_->CreateFCmpONE(value, zero);
} else {
llvm::Constant* zero = llvm::ConstantInt::get(LLVMType(from), 0);
llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
return builder_->CreateICmpNE(value, zero);
}
} else if (!from.is_float() && !to.is_float()) {
......@@ -570,7 +593,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va
return builder_->CreateFPToSI(value, target);
} else if (from.is_float() && to.is_uint()) {
if (to.bits() < 8) {
value = builder_->CreateFPToUI(value, LLVMType(to.with_bits(8)));
value = builder_->CreateFPToUI(value, DTypeToLLVMType(to.with_bits(8)));
return builder_->CreateIntCast(value, target, false);
} else {
return builder_->CreateFPToUI(value, target);
......@@ -610,7 +633,7 @@ llvm::Value* CodeGenLLVM::CreateBufferPtr(
CHECK_EQ(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(btype->getAddressSpace());
if (btype != ptype) {
buffer = builder_->CreatePointerCast(buffer, ptype);
}
......@@ -623,7 +646,8 @@ llvm::Value* CodeGenLLVM::CreateBufferVecPtr(
CHECK_GT(t.lanes(), 1);
llvm::PointerType* btype = llvm::dyn_cast<llvm::PointerType>(buffer->getType());
CHECK(btype != nullptr);
llvm::PointerType* ptype = LLVMType(t)->getPointerTo(btype->getAddressSpace());
llvm::PointerType* ptype = DTypeToLLVMType(t)->getPointerTo(
btype->getAddressSpace());
if (btype != ptype) {
buffer = builder_->CreatePointerCast(buffer, ptype);
}
......@@ -644,7 +668,7 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) {
arg_type.push_back(arg_value.back()->getType());
}
llvm::FunctionType* ftype = llvm::FunctionType::get(
LLVMType(op->dtype), arg_type, false);
GetLLVMType(GetRef<PrimExpr>(op)), arg_type, false);
llvm::Function* f = module_->getFunction(op->name);
if (f == nullptr) {
f = llvm::Function::Create(
......@@ -669,7 +693,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
sig_type.push_back(arg_value.back()->getType());
}
}
llvm::Type *return_type = LLVMType(op->dtype);
llvm::Type *return_type = GetLLVMType(GetRef<PrimExpr>(op));
if (sig_type.size() > 0 && return_type != sig_type[0]) {
sig_type.insert(sig_type.begin(), return_type);
}
......@@ -722,7 +746,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
uint64_t val = (high << 32U) | low;
return llvm::ConstantInt::get(LLVMType(op->dtype), val);
return llvm::ConstantInt::get(DTypeToLLVMType(op->dtype), val);
} else if (op->is_intrinsic(intrinsic::tvm_if_then_else)) {
CHECK_EQ(op->args[0].dtype().lanes(), 1)
<< "if_then_else can only take scalar condition";
......@@ -748,7 +772,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
value->addIncoming(else_value, else_value_block);
return value;
} else if (op->is_intrinsic(CallNode::reinterpret)) {
llvm::Type * target = LLVMType(op->dtype);
llvm::Type * target = DTypeToLLVMType(op->dtype);
return builder_->CreateBitCast(MakeValue(op->args[0]), target);
} else if (op->is_intrinsic(CallNode::isnan)) {
// TODO(hgt312): set fast math flag
......@@ -802,11 +826,11 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CastNode* op) {
return CreateCast(op->value.dtype(), op->dtype, MakeValue(op->value));
}
llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) {
return llvm::ConstantInt::getSigned(LLVMType(op->dtype), op->value);
return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) {
return llvm::ConstantFP::get(LLVMType(op->dtype), op->value);
return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value);
}
llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) {
......@@ -970,7 +994,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
CHECK_EQ(ramp->lanes, t.lanes());
llvm::Value* ptr = CreateBufferPtr(
t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
ptr = builder_->CreatePointerCast(
ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
llvm::LoadInst* load = builder_->CreateAlignedLoad(ptr, alignment, is_volatile);
AddAliasInfo(load, op->buffer_var.get(), op->index, t);
return load;
......@@ -979,7 +1004,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const LoadNode* op) {
}
// scalarized load.
int basic_align = t.bits() / 8;
llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(t));
auto f = [&](int i, llvm::Value* index) {
llvm::Value* ptr = CreateBufferPtr(t.element_of(), buffer, index);
llvm::LoadInst* load = builder_->CreateAlignedLoad(
......@@ -1007,7 +1032,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
llvm::Value* vec = llvm::UndefValue::get(LLVMType(op->dtype));
llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
for (int i = 0; i < op->lanes; ++i) {
vec = builder_->CreateInsertElement(
vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(), i)),
......@@ -1066,7 +1091,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
CHECK_EQ(ramp->lanes, t.lanes());
llvm::Value* ptr = CreateBufferPtr(
t.element_of(), buffer, MakeValue(ramp->base));
ptr = builder_->CreatePointerCast(ptr, LLVMType(t)->getPointerTo(addrspace));
ptr = builder_->CreatePointerCast(ptr, DTypeToLLVMType(t)->getPointerTo(addrspace));
llvm::StoreInst* store = builder_->CreateAlignedStore(value, ptr, alignment, is_volatile);
AddAliasInfo(store, op->buffer_var.get(), op->index, op->value.dtype());
return;
......@@ -1147,7 +1172,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
}
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
LLVMType(op->dtype), ConstInt32(constant_size));
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
......@@ -1160,7 +1185,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
buf = alloca;
}
buf = builder_->CreatePointerCast(
buf, LLVMType(op->dtype)->getPointerTo(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
......@@ -1168,7 +1193,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) {
}
void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent) {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) {
......
......@@ -25,12 +25,17 @@
#define TVM_TARGET_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
#include <memory>
#include <utility>
#include <vector>
......@@ -78,7 +83,7 @@ class CodeGenLLVM :
* \brief Compile and add function f to the current module.
* \param f The function to be added.
*/
virtual void AddFunction(const LoweredFunc& f);
virtual void AddFunction(const PrimFunc& f);
/*!
* \brief Add main function as the entry name
* \param entry_func_name The name of entry function to be added.
......@@ -167,7 +172,7 @@ class CodeGenLLVM :
* \return The result.
*/
template<typename F>
inline llvm::AllocaInst* WithFunctionEntry(F falloca) {
llvm::AllocaInst* WithFunctionEntry(F falloca) {
llvm::BasicBlock* current = builder_->GetInsertBlock();
llvm::BasicBlock* entry = &(function_->getEntryBlock());
builder_->SetInsertPoint(entry, entry->begin());
......@@ -198,18 +203,35 @@ class CodeGenLLVM :
// Get the maximum storage align bits of buffer pointer given storage scope.
virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
// Get correct address space depending on the backend
virtual unsigned GetGlobalAddressSpace();
void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
virtual unsigned GetGlobalAddressSpace() const;
void AddFunctionInternal(const PrimFunc& f, bool ret_void);
// Create extern call
llvm::CallInst* CreateCallExtern(llvm::Type* ret,
const std::string& name,
const std::vector<llvm::Value*>& value);
/*!
* \param t The original type.
* \return LLVM type of t
* \brief Get the LLVM Type for a given runtime type.
* \param dtype The runtime dtype.
*
* \note Only use this function for dealing with PrimTypes.
* For Call and Var that could have more refined types,
* use GetLLVMType instead.
*
* \return LLVM type of dtype
*/
llvm::Type* DTypeToLLVMType(const DataType& dtype) const;
/*!
* \brief Get the LLVM Type for a given type.
* \param type The corresponding TVM Type.
* \return The LLVM type corresponding to type.
*/
llvm::Type* GetLLVMType(const Type& type) const;
/*!
* \brief Get the LLVM Type for a given expression.
* \param expr The expression whose LLVM type is requested.
* \return The LLVM type corresponding to expr.
*/
llvm::Type* LLVMType(const DataType& t) const;
llvm::Type* GetLLVMType(const PrimExpr& expr) const;
// initialize the function state.
void InitFuncState();
// Get alignment given index.
......
......@@ -34,7 +34,7 @@ namespace codegen {
// NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM {
public:
void AddFunction(const LoweredFunc& f) final {
void AddFunction(const PrimFunc& f) final {
// add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true);
// annotate as kernel function
......@@ -68,7 +68,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
// TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca(
LLVMType(op->dtype), ConstInt32(constant_size));
DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
});
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
......@@ -83,7 +83,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
<< "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3
const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size);
llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
......@@ -97,7 +98,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
}
}
buf = builder_->CreatePointerCast(
buf, LLVMType(op->dtype)->getPointerTo(
buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf;
......@@ -190,7 +191,7 @@ inline int DetectCUDAComputeVersion() {
}
}
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
runtime::Module BuildNVPTX(IRModule mod, std::string target) {
InitializeLLVM();
CHECK(target.length() >= 5 &&
target.substr(0, 5) == "nvptx");
......@@ -202,8 +203,13 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false);
for (LoweredFunc f : funcs) {
cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false);
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f);
}
......@@ -249,10 +255,10 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
#endif
pass.run(*module);
std::string ptx(data_ptx.begin(), data_ptx.end());
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll);
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll);
}
TVM_REGISTER_GLOBAL("codegen.build_nvptx")
TVM_REGISTER_GLOBAL("target.build.nvptx")
.set_body_typed(BuildNVPTX);
} // namespace codegen
......
......@@ -88,7 +88,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
if (from.lanes() >= 16 && has_avx512) {
return CallVectorIntrin(
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
LLVMType(DataType::Float(32, from.lanes())),
DTypeToLLVMType(DataType::Float(32, from.lanes())),
{
MakeValue(tir::CallNode::make(
DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
......@@ -103,7 +103,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin(
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())),
::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
DTypeToLLVMType(DataType::Float(32, from.lanes())),
{MakeValue(tir::CallNode::make(
DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
tir::CallNode::PureIntrinsic))});
......
......@@ -25,6 +25,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
#include <tvm/target/codegen.h>
#include <mutex>
#include "llvm_common.h"
......@@ -192,21 +193,39 @@ class LLVMModuleNode final : public runtime::ModuleNode {
return "";
}
void Init(const Array<LoweredFunc>& funcs, std::string target) {
void Init(const IRModule& mod, std::string target) {
InitializeLLVM();
tm_ = GetLLVMTargetMachine(target);
bool system_lib = (target.find("-system-lib") != std::string::npos);
CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_.get(), ctx_.get(), system_lib, system_lib);
for (LoweredFunc f : funcs) {
std::vector<PrimFunc> funcs;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
entry_func_ = global_symbol;
}
funcs.push_back(f);
}
CHECK_NE(funcs.size(), 0U);
// TODO(tqchen): remove the entry function behavior as it does not
// make sense when we start to use multiple modules.
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib);
for (const auto& f : funcs) {
cg->AddFunction(f);
}
cg->AddMainFunction(funcs[0]->name);
module_ = cg->Finish();
if (entry_func_.length() != 0) {
cg->AddMainFunction(entry_func_);
}
module_ = cg->Finish();
module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, target));
module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION);
......@@ -349,12 +368,14 @@ unsigned LookupLLVMIntrinsic(const std::string& name) {
return llvm::Function::lookupIntrinsicID(name);
}
TVM_REGISTER_GLOBAL("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
auto n = make_object<LLVMModuleNode>();
n->Init(args[0].operator Array<LoweredFunc>(), args[1].operator std::string());
*rv = runtime::Module(n);
});
// Build entry point for the LLVM backend: wraps the given IRModule and
// target string into a freshly initialized LLVMModuleNode.
TVM_REGISTER_GLOBAL("target.build.llvm")
.set_body_typed([](IRModule mod, std::string target) -> runtime::Module {
  auto node = make_object<LLVMModuleNode>();
  node->Init(mod, target);
  return runtime::Module(node);
});
TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
.set_body([](TVMArgs args, TVMRetValue *rv) {
......
......@@ -127,15 +127,23 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
return ptx;
}
runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
runtime::Module BuildCUDA(IRModule mod) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
......@@ -151,10 +159,10 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
} else {
ptx = NVRTCCompile(code, cg.need_include_path());
}
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code);
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("codegen.build_cuda")
TVM_REGISTER_GLOBAL("target.build.cuda")
.set_body_typed(BuildCUDA);
} // namespace codegen
} // namespace tvm
......@@ -31,16 +31,26 @@
namespace tvm {
namespace codegen {
runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
runtime::Module BuildAOCL(IRModule mod,
std::string target_str,
bool emulation) {
// Get code.
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodegenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string();
......@@ -68,15 +78,15 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
std::string aocxbin;
runtime::LoadBinaryFromFile("aocl.aocx", &aocxbin);
return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(funcs), code);
return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("codegen.build_aocl")
TVM_REGISTER_GLOBAL("target.build.aocl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildAOCL(args[0], args[1], false);
});
TVM_REGISTER_GLOBAL("codegen.build_aocl_sw_emu")
// Software-emulation build entry point: forwards to BuildAOCL with
// emulation enabled. The name follows the same "target.build.<name>"
// scheme as the sibling "target.build.aocl" registration; the previous
// spelling carried a duplicated "build." segment.
TVM_REGISTER_GLOBAL("target.build.aocl_sw_emu")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  *rv = BuildAOCL(args[0], args[1], true);
});
......
......@@ -35,7 +35,7 @@ void CodeGenC::Init(bool output_ssa) {
print_ssa_form_ = output_ssa;
}
void CodeGenC::InitFuncState(LoweredFunc f) {
void CodeGenC::InitFuncState(const PrimFunc& f) {
alloc_storage_scope_.clear();
handle_data_type_.clear();
CodeGenSourceBase::ClearFuncState();
......@@ -72,39 +72,46 @@ void CodeGenC::ReserveKeywordsAsUnique() {
GetUniqueName("return");
}
void CodeGenC::AddFunction(LoweredFunc f) {
void CodeGenC::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
ReserveKeywordsAsUnique();
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.dtype());
}
this->stream << "void " << f->name << "(";
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix();
this->stream << " " << static_cast<std::string>(global_symbol) << "(";
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end())
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
stream << ' ';
}
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
} else {
stream << "void";
PrintType(GetType(v), stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
stream << "*";
if (f->is_restricted && restrict_keyword_.length() != 0) {
if (no_alias && restrict_keyword_.length() != 0) {
stream << ' ' << restrict_keyword_;
}
} else {
PrintType(v.dtype(), stream);
PrintType(GetType(v), stream);
}
stream << ' ' << vid;
}
......@@ -112,11 +119,19 @@ void CodeGenC::AddFunction(LoweredFunc f) {
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->PrintFinalReturn();
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
}
void CodeGenC::PrintFuncPrefix() {
stream << "void";
}
// Hook for emitting a final return statement; the base implementation
// intentionally emits nothing.
void CodeGenC::PrintFinalReturn() {
}
// Return the generated code: the declaration stream contents followed by
// the main stream contents.
std::string CodeGenC::Finish() {
  std::string code = decl_stream.str();
  code += stream.str();
  return code;
}
......@@ -275,7 +290,6 @@ std::string CodeGenC::GetStructRef(
}
}
bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) return false;
......@@ -370,6 +384,20 @@ void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
// Print the C spelling of a TVM Type into os.
//  - PrimType: delegates to the DataType overload.
//  - PointerType: prints the element type followed by '*'.
//  - void type: prints "void".
// Any other type is a fatal error.
void CodeGenC::PrintType(const Type& type, std::ostream& os) {  // NOLINT(*)
  if (const auto* prim = type.as<PrimTypeNode>()) {
    PrintType(prim->dtype, os);
    return;
  }
  if (const auto* pointer = type.as<PointerTypeNode>()) {
    PrintType(pointer->element_type, os);
    os << '*';
    return;
  }
  if (IsVoidType(type)) {
    os << "void";
    return;
  }
  LOG(FATAL) << "Type " << type << " does not have a corresponding C Type";
}
inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->dtype == DataType::Int(32)) {
std::ostringstream temp;
......
......@@ -26,9 +26,11 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/container.h>
#include <string>
#include <vector>
#include <unordered_map>
......@@ -62,8 +64,9 @@ class CodeGenC :
/*!
* \brief Add the function to the generated module.
* \param f The function to be compiled.
* \note The trailing return statement, if any, is emitted through PrintFinalReturn().
*/
void AddFunction(LoweredFunc f);
void AddFunction(const PrimFunc& f);
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
......@@ -93,15 +96,25 @@ class CodeGenC :
}
// The following parts are overloadable print operations.
/*!
* \brief Print the function header before the argument list
*
* Example: stream << "void";
*/
virtual void PrintFuncPrefix(); // NOLINT(*)
/*!
* \brief Print the final return at the end of the function.
*/
virtual void PrintFinalReturn(); // NOLINT(*)
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
*/
virtual void PreFunctionBody(LoweredFunc f) {}
virtual void PreFunctionBody(const PrimFunc& f) {}
/*!
* \brief Initialize codegen state for generating f.
* \param f The function to be compiled.
*/
virtual void InitFuncState(LoweredFunc f);
virtual void InitFuncState(const PrimFunc& f);
// expression
void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
......@@ -149,6 +162,12 @@ class CodeGenC :
*/
virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
/*!
* \brief Print Type representation of type type.
* \param type The type representation.
* \param os The stream to print the ctype into
*/
virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*)
/*!
* \brief Print expr representing the thread tag
* \param IterVar iv The thread index to be bound;
*/
......@@ -223,12 +242,6 @@ class CodeGenC :
// override
void PrintSSAAssign(
const std::string& target, const std::string& src, DataType t) final;
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique();
......@@ -237,6 +250,13 @@ class CodeGenC :
return volatile_buf_.count(buf_var) != 0;
}
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;
private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{false};
......
......@@ -41,59 +41,16 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) {
CodeGenC::Init(output_ssa);
}
void CodeGenCHost::AddFunction(LoweredFunc f) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
ReserveKeywordsAsUnique();
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.dtype());
}
this->stream << "#ifdef __cplusplus\n";
this->stream << "extern \"C\"\n";
this->stream << "#endif\n";
this->stream << "TVM_DLL int32_t " << f->name << "(";
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
stream << ' ';
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
} else {
stream << "void";
}
stream << "*";
if (f->is_restricted && restrict_keyword_.length() != 0) {
stream << ' ' << restrict_keyword_;
}
} else {
PrintType(v.dtype(), stream);
}
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->PrintIndent();
this->stream << "return 0;\n";
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
void CodeGenCHost::PrintFuncPrefix() { // NOLINT(*)
stream << "#ifdef __cplusplus\n"
<< "extern \"C\"\n"
<< "#endif\n"
<< "TVM_DLL int32_t";
}
std::string CodeGenCHost::Finish() {
return CodeGenC::Finish();
void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
this->PrintIndent();
stream << "return 0;\n";
}
void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
......@@ -277,20 +234,25 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op,
<< "? (" << a_id << ") : (" << b_id << "))";
}
runtime::Module BuildCHost(Array<LoweredFunc> funcs) {
runtime::Module BuildCHost(IRModule mod) {
using tvm::runtime::Registry;
bool output_ssa = false;
bool emit_asserts = false;
CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodegenCHost: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
cg.AddFunction(f);
}
std::string code = cg.Finish();
return CSourceModuleCreate(code, "c");
}
TVM_REGISTER_GLOBAL("codegen.build_c")
TVM_REGISTER_GLOBAL("target.build.c")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCHost(args[0]);
});
......
......@@ -36,10 +36,10 @@ class CodeGenCHost final : public CodeGenC {
public:
CodeGenCHost();
void Init(bool output_ssa, bool emit_asserts);
void AddFunction(LoweredFunc f);
std::string Finish();
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix() final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)
// overload visitor functions
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
......
......@@ -43,9 +43,9 @@ void CodeGenCUDA::Init(bool output_ssa) {
CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}
void CodeGenCUDA::AddFunction(LoweredFunc f) {
this->stream << "extern \"C\" __global__ ";
CodeGenC::AddFunction(f);
void CodeGenCUDA::PrintFuncPrefix() {
stream << "extern \"C\" __global__ void";
}
std::string CodeGenCUDA::Finish() {
......@@ -424,11 +424,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
}
void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::fragment_shape) {
if (op->attr_key == tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) {
} else if (op->attr_key == tir::attr::fragment_layout) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
......
......@@ -37,12 +37,12 @@ class CodeGenCUDA final : public CodeGenC {
public:
CodeGenCUDA();
void Init(bool output_ssa);
void AddFunction(LoweredFunc f);
std::string Finish();
bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
}
// override behavior
void PrintFuncPrefix() final;
void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
......
......@@ -31,10 +31,10 @@
namespace tvm {
namespace codegen {
void CodeGenMetal::InitFuncState(LoweredFunc f) {
void CodeGenMetal::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
// analyze the data;
for (Var arg : f->args) {
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
......@@ -49,48 +49,55 @@ CodeGenMetal::CodeGenMetal() {
<< "};\n\n";
}
void CodeGenMetal::AddFunction(LoweredFunc f) {
void CodeGenMetal::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
GetUniqueName("_");
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.dtype());
}
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
// Function header.
this->stream << "kernel void " << f->name << "(\n";
this->stream << "kernel void " << static_cast<std::string>(global_symbol) << "(";
// Buffer arguments
size_t num_buffer = 0;
for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) {
Var v = f->args[i];
for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
Var v = f->params[i];
if (!v.dtype().is_handle()) break;
stream << " ";
std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get());
CHECK(it != alloc_storage_scope_.end());
PrintStorageScope(it->second, stream);
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
stream << ' ';
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
stream << "*";
} else {
PrintType(v.dtype(), stream);
PrintType(GetType(v), stream);
// Register handle data type
// TODO(tvm-team): consider simply keep type info in the
// type annotation(via a normalizing rewriting).
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
}
stream << ' ' << vid
<< " [[ buffer(" << i << ") ]],\n";
}
// Setup normal arguments.
size_t nargs = f->args.size() - num_buffer;
size_t nargs = f->params.size() - num_buffer;
std::string varg = GetUniqueName("arg");
if (nargs != 0) {
std::string arg_buf_type = f->name + "_args_t";
std::string arg_buf_type = static_cast<std::string>(global_symbol) + "_args_t";
stream << " constant " << arg_buf_type << "& " << varg
<< " [[ buffer(" << num_buffer << ") ]],\n";
// declare the struct
decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < f->args.size(); ++i) {
Var v = f->args[i];
for (size_t i = num_buffer; i < f->params.size(); ++i) {
Var v = f->params[i];
CHECK(!v.dtype().is_handle());
std::string vid = AllocVarID(v.get());
std::ostringstream vref;
......@@ -113,7 +120,10 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
int work_dim = 0;
for (IterVar iv : f->thread_axis) {
auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis);
CHECK(thread_axis.defined());
for (IterVar iv : thread_axis) {
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
work_dim = std::max(work_dim, scope.dim_index + 1);
}
......@@ -127,7 +137,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
stream << " threadIdx [[thread_position_in_threadgroup]]\n";
}
// bind thread axis
for (IterVar iv : f->thread_axis) {
for (IterVar iv : thread_axis) {
CHECK(!var_idmap_.count(iv->var.get()));
std::string vname = iv->thread_tag;
if (work_dim <= 1) {
......@@ -257,14 +267,23 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
}
}
runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
runtime::Module BuildMetal(IRModule mod) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenMetal cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenMetal: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
std::string fmt = "metal";
std::string source = "";
......@@ -273,10 +292,10 @@ runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
code = (*f)(code).operator std::string();
fmt = "metallib";
}
return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source);
return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source);
}
TVM_REGISTER_GLOBAL("codegen.build_metal")
TVM_REGISTER_GLOBAL("target.build.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildMetal(args[0]);
});
......
......@@ -34,10 +34,10 @@ namespace codegen {
class CodeGenMetal final : public CodeGenC {
public:
CodeGenMetal();
void AddFunction(LoweredFunc f);
// override print thread tag.
void PrintArgUnionDecl();
void InitFuncState(LoweredFunc f) final;
void AddFunction(const PrimFunc& f); // NOLINT(*)
void InitFuncState(const PrimFunc& f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
......@@ -50,9 +50,10 @@ class CodeGenMetal final : public CodeGenC {
const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
// reuse parent's function.
using CodeGenC::PrintType;
private:
int thread_index_bits_{32};
......
......@@ -35,18 +35,17 @@ CodeGenOpenCL::CodeGenOpenCL() {
restrict_keyword_ = "restrict";
}
void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
for (Var arg : f->args) {
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
}
void CodeGenOpenCL::AddFunction(LoweredFunc f) {
this->stream << "__kernel ";
CodeGenC::AddFunction(f);
void CodeGenOpenCL::PrintFuncPrefix() {
stream << "__kernel void";
}
std::string CodeGenOpenCL::Finish() {
......@@ -239,50 +238,31 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
}
}
/*!
 * \brief Print a two-operand call of the form opstr((Ta)a, (Tb)b).
 *
 * Scalar expressions have each operand prefixed with a cast to that
 * operand's own dtype; any other lane count is delegated to
 * PrintVecBinaryOp.
 *
 * \param op The binary node providing dtype, a and b.
 * \param opstr The name of the function to call in the generated code.
 * \param os The stream to print into.
 * \param p The code generator used to print types and sub-expressions.
 */
template<typename T>
inline void PrintBinaryExpr(const T* op,
                            const char* opstr,
                            std::ostream& os,
                            CodeGenOpenCL* p) {
  // Vector case: hand over to the generic vector binary-op printer.
  if (op->dtype.lanes() != 1) {
    p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
    return;
  }
  // Scalar case: opstr((type_a)a, (type_b)b)
  os << opstr << "((";
  p->PrintType(op->a->dtype, os);
  os << ")";
  p->PrintExpr(op->a, os);
  os << ", (";
  p->PrintType(op->b->dtype, os);
  os << ")";
  p->PrintExpr(op->b, os);
  os << ')';
}
// Print a MinNode as a call to "min" via PrintBinaryExpr.
void CodeGenOpenCL::VisitExpr_(const MinNode *op, std::ostream& os) {
PrintBinaryExpr(op, "min", os, this);
}
// Print a MaxNode as a call to "max" via PrintBinaryExpr.
void CodeGenOpenCL::VisitExpr_(const MaxNode *op, std::ostream& os) {
PrintBinaryExpr(op, "max", os, this);
}
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
runtime::Module BuildOpenCL(IRModule mod) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string();
}
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs), code);
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code);
}
TVM_REGISTER_GLOBAL("codegen.build_opencl")
TVM_REGISTER_GLOBAL("target.build.opencl")
.set_body_typed(BuildOpenCL);
} // namespace codegen
} // namespace tvm
......@@ -34,11 +34,11 @@ namespace codegen {
class CodeGenOpenCL final : public CodeGenC {
public:
CodeGenOpenCL();
void AddFunction(LoweredFunc f);
std::string Finish();
// override print thread tag.
void InitFuncState(LoweredFunc f) final;
void InitFuncState(const PrimFunc& f) final;
void PrintFuncPrefix() final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
......@@ -56,9 +56,6 @@ class CodeGenOpenCL final : public CodeGenC {
// overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
// overload min and max to avoid ambiguous call errors
void VisitExpr_(const MinNode *op, std::ostream& os) final;
void VisitExpr_(const MaxNode *op, std::ostream& os) final;
private:
// whether enable fp16 and fp64 extension
......
......@@ -37,7 +37,7 @@ namespace codegen {
CodeGenOpenGL::CodeGenOpenGL()
: output_(nullptr), output_iter_var_(nullptr) {}
void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
void CodeGenOpenGL::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f);
output_ = nullptr;
inputs_.clear();
......@@ -47,7 +47,7 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
this->stream.str("");
}
void CodeGenOpenGL::AddFunction(LoweredFunc f) {
void CodeGenOpenGL::AddFunction(const PrimFunc& f) {
// clear previous generated state.
this->InitFuncState(f);
......@@ -56,15 +56,17 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
// skip the first underscore, so SSA variable starts from _1
GetUniqueName("_");
// add to alloc buffer type.
for (const auto& kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.dtype());
}
// Allocate argument names. Store in `var_idmap_`.
for (auto arg : f->args) {
for (auto arg : f->params) {
auto arg_name = GetUniqueName(arg.get()->name_hint);
var_idmap_[arg.get()] = arg_name;
if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(arg.get(), prim->dtype);
}
}
}
thread_extent_var_ = GetUniqueName("thread_extent");
......@@ -80,7 +82,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
this->stream << "}\n\n";
// Declare arguments.
for (auto arg : f->args) {
for (auto arg : f->params) {
if (this->inputs_.find(arg.get()) != this->inputs_.cend()) {
// Declare input texture.
// Format:
......@@ -138,7 +140,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
std::vector<std::string> arg_names;
std::vector<runtime::OpenGLArgKind> arg_kinds;
for (auto arg : f->args) {
for (auto arg : f->params) {
std::string name = GetVarID(arg.get());
runtime::OpenGLArgKind kind;
......@@ -154,7 +156,11 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
arg_kinds.push_back(kind);
}
shaders_[f->name] = runtime::OpenGLShader(
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
shaders_[static_cast<std::string>(global_symbol)] = runtime::OpenGLShader(
this->decl_stream.str() + this->stream.str(),
std::move(arg_names), std::move(arg_kinds),
this->thread_extent_var_);
......@@ -283,18 +289,27 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
}
runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
runtime::Module BuildOpenGL(IRModule mod) {
bool output_ssa = false;
CodeGenOpenGL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenOpenGL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
auto shaders = cg.Finish();
return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs));
return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(mod));
}
TVM_REGISTER_GLOBAL("codegen.build_opengl")
TVM_REGISTER_GLOBAL("target.build.opengl")
.set_body_typed(BuildOpenGL);
} // namespace codegen
......
......@@ -37,10 +37,10 @@ namespace codegen {
class CodeGenOpenGL final : public CodeGenC {
public:
CodeGenOpenGL();
void AddFunction(LoweredFunc f);
std::unordered_map<std::string, runtime::OpenGLShader> Finish();
void InitFuncState(LoweredFunc f) final;
void AddFunction(const PrimFunc& f);
void InitFuncState(const PrimFunc& f) final;
void BindThreadIndex(const IterVar& iv) final;
void VisitStmt_(const StoreNode* op) final;
std::string TexelFetch(const VarNode* buffer, PrimExpr index);
......
......@@ -68,14 +68,13 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
}
}
void CodeGenVivadoHLS::AddFunction(LoweredFunc f) {
this->stream << "extern \"C\" ";
CodeGenC::AddFunction(f);
void CodeGenVivadoHLS::PrintFuncPrefix() {
stream << "extern \"C\" void";
}
void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
Var v = f->params[i];
std::string vid = GetVarID(v.get());
if (v.dtype().is_handle()) {
this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n";
......@@ -126,21 +125,34 @@ void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOL
}
runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenVivadoHLS cg;
// Generate source code for get_source().
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenVHLS: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
}
std::string whole_code = cg.Finish();
// Generate source code for compilation.
Array<Array<PrimExpr> > kernel_info;
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
CodeGenVivadoHLS cg;
cg.Init(output_ssa);
cg.AddFunction(f);
......@@ -148,7 +160,12 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code).operator std::string();
}
kernel_info.push_back(Array<PrimExpr>({f->name, code}));
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
std::string func_name = global_symbol;
kernel_info.push_back(Array<PrimExpr>({func_name, code}));
}
std::string xclbin;
......@@ -158,10 +175,10 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
} else {
LOG(FATAL) << "Cannot compile Vivado HLS code.";
}
return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), whole_code);
return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code);
}
TVM_REGISTER_GLOBAL("codegen.build_sdaccel")
TVM_REGISTER_GLOBAL("target.build.sdaccel")
.set_body_typed(BuildSDAccel);
} // namespace codegen
......
......@@ -15,7 +15,7 @@
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
*/
/*!
* \file codegen_vhls.h
......@@ -37,10 +37,11 @@ class CodeGenVivadoHLS final : public CodeGenC {
public:
void Init(bool output_ssa);
void PrintType(DataType t, std::ostream& os);
void AddFunction(LoweredFunc f);
void PreFunctionBody(LoweredFunc f);
void VisitExpr_(const MinNode *op, std::ostream& os);
void VisitExpr_(const MaxNode *op, std::ostream& os);
void PrintFuncPrefix() final;
void PreFunctionBody(const PrimFunc& f) final;
void VisitExpr_(const MinNode *op, std::ostream& os) final;
void VisitExpr_(const MaxNode *op, std::ostream& os) final;
};
} // namespace codegen
......
......@@ -70,7 +70,7 @@ class SPIRVTools {
spv_context ctx_;
};
runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
runtime::Module BuildSPIRV(IRModule mod) {
using tvm::runtime::Registry;
using tvm::runtime::VulkanShader;
......@@ -81,8 +81,21 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
CodeGenSPIRV cg;
for (LoweredFunc f : funcs) {
f = PointerValueTypeRewrite(f);
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenSPIRV: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
f = PointerValueTypeRewrite(std::move(f));
VulkanShader shader;
shader.data = cg.BuildFunction(f);
......@@ -97,13 +110,14 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
reinterpret_cast<char*>(dmlc::BeginPtr(shader.data)));
}
code_data << spirv_tools.BinaryToText(shader.data);
smap[f->name] = std::move(shader);
smap[f_name] = std::move(shader);
}
return runtime::VulkanModuleCreate(
smap, ExtractFuncInfo(funcs), code_data.str());
smap, ExtractFuncInfo(mod), code_data.str());
}
TVM_REGISTER_GLOBAL("codegen.build_vulkan")
TVM_REGISTER_GLOBAL("target.build.vulkan")
.set_body_typed(BuildSPIRV);
} // namespace codegen
......
......@@ -23,6 +23,7 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/container.h>
#include <string>
#include "codegen_spirv.h"
#include "../../arith/compute_expr.h"
......@@ -30,18 +31,20 @@
namespace tvm {
namespace codegen {
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
this->InitFuncState();
CHECK(f->is_restricted)
CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias))
<< "SPIRV only takes restricted memory model";
std::vector<Var> pod_args;
uint32_t num_buffer = 0;
for (Var arg : f->args) {
for (Var arg : f->params) {
DataType t = arg.dtype();
if (t.is_handle()) {
auto it = f->handle_data_type.find(arg);
if (it != f->handle_data_type.end()) {
DataType value_type = (*it).second.dtype();
if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
auto* prim = ptr->element_type.as<PrimTypeNode>();
CHECK(prim);
DataType value_type = prim->dtype;
spirv::Value arg_value = builder_->BufferArgument(
builder_->GetSType(value_type), 0, num_buffer);
storage_info_[arg.get()].UpdateContentType(value_type);
......@@ -75,7 +78,11 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd);
builder_->CommitKernelFunction(func_ptr, f->name);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
builder_->CommitKernelFunction(func_ptr, static_cast<std::string>(global_symbol));
return builder_->Finalize();
}
......@@ -607,7 +614,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
}
void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent) {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) {
......
......@@ -53,7 +53,7 @@ class CodeGenSPIRV:
* \param f The function to be added.
* \return The final spirv module.
*/
virtual std::vector<uint32_t> BuildFunction(const LoweredFunc& f);
virtual std::vector<uint32_t> BuildFunction(const PrimFunc& f);
/*!
* \brief Create Value for expression e
* \param e The expression to be created value for.
......
......@@ -21,7 +21,10 @@
* \file codegen_stackvm.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/ir/module.h>
#include <tvm/tir/op.h>
#include <tvm/tir/function.h>
#include <limits>
#include <utility>
#include "codegen_stackvm.h"
......@@ -54,9 +57,9 @@ StackVM::StructFieldKind MapFieldKind(int64_t kind) {
return StackVM::kArrData;
}
StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
StackVM CodeGenStackVM::Compile(const PrimFunc& f) {
for (size_t i = 0; i < f->params.size(); ++i) {
Var v = f->params[i];
int vid = AllocVarID(v.get());
CHECK_EQ(static_cast<size_t>(vid), i);
}
......@@ -525,19 +528,32 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->body);
}
runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) {
CHECK_NE(funcs.size(), 0U);
runtime::Module BuildStackVM(const IRModule& mod) {
std::unordered_map<std::string, StackVM> fmap;
for (LoweredFunc f : funcs) {
std::string entry_func;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenStackVM: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
StackVM vm = codegen::CodeGenStackVM().Compile(f);
CHECK(!fmap.count(f->name))
<< "Function name " << f->name << "already exist in list";
fmap[f->name] = std::move(vm);
CHECK(!fmap.count(f_name))
<< "Function name " << f_name << "already exist in list";
fmap[f_name] = std::move(vm);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
entry_func = f_name;
}
}
return runtime::StackVMModuleCreate(fmap, funcs[0]->name);
return runtime::StackVMModuleCreate(fmap, entry_func);
}
TVM_REGISTER_GLOBAL("codegen.build_stackvm")
TVM_REGISTER_GLOBAL("target.build.stackvm")
.set_body_typed(BuildStackVM);
} // namespace codegen
} // namespace tvm
......@@ -56,7 +56,7 @@ class CodeGenStackVM
* \note Only call compile once,
* create a new codegen object each time.
*/
StackVM Compile(LoweredFunc f);
StackVM Compile(const PrimFunc& f);
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
......
......@@ -91,7 +91,7 @@ Stmt MakeCrossThreadReduction(
freduce_args, CallNode::Intrinsic));
reduce_body = AttrStmtNode::make(
reduces[0]->combiner,
attr::reduce_scope,
tir::attr::reduce_scope,
make_zero(DataType::Handle()),
reduce_body);
std::vector<Stmt> assigns(size);
......@@ -109,7 +109,7 @@ Stmt MakeCrossThreadReduction(
body = AllocateNode::make(
res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
body = AttrStmtNode::make(
res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body);
res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
}
body = Substitute(body, value_map);
return MergeNest(nest, body);
......
......@@ -165,7 +165,8 @@ Stmt ExternOpNode::BuildProvide(
const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
Stmt ret = AttrStmtNode::make(
make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<ObjectRef> bind_spec;
Array<PrimExpr> tuple;
......@@ -176,7 +177,7 @@ Stmt ExternOpNode::BuildProvide(
tuple.push_back(buffer->shape[k]);
}
ret = AttrStmtNode::make(
bind_spec, attr::buffer_bind_scope,
bind_spec, tir::attr::buffer_bind_scope,
CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
};
for (size_t i = output_placeholders.size(); i != 0; --i) {
......
......@@ -186,7 +186,8 @@ Stmt HybridOpNode::BuildProvide(
const std::unordered_map<IterVar, Range> &dom_map,
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body);
Stmt ret = AttrStmtNode::make(
make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
......
......@@ -168,7 +168,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar
if (!new_loop_var) {
nest[i + 1].emplace_back(
AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op));
AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op));
}
}
// message passing to get offset of root iter vars.
......
......@@ -287,10 +287,10 @@ Stmt ScanOpNode::BuildProvide(
bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this);
Stmt provide = AttrStmtNode::make(
stage->op, attr::scan_update_scope, this->scan_axis->var,
stage->op, tir::attr::scan_update_scope, this->scan_axis->var,
EvaluateNode::make(0));
Stmt init = AttrStmtNode::make(
stage->op, attr::scan_init_scope, 0,
stage->op, tir::attr::scan_init_scope, 0,
EvaluateNode::make(0));
size_t begin_scan = 0;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
......
......@@ -85,7 +85,7 @@ class InjectAttach : public StmtMutator {
auto stmt = StmtMutator::VisitStmt(input_stmt);
const AttrStmtNode* op = stmt.as<AttrStmtNode>();
if (op != nullptr &&
op->attr_key == attr::loop_scope) {
op->attr_key == tir::attr::loop_scope) {
if (attach_spec_->attach_type == kScope &&
op->node == attach_spec_->attach_ivar) {
CHECK(!found_attach)
......@@ -131,8 +131,8 @@ class InjectScanStep : public StmtMutator {
// update
const AttrStmtNode* op = stmt.as<AttrStmtNode>();
if (op != nullptr &&
((op->attr_key == attr::scan_update_scope && !is_init_) ||
(op->attr_key == attr::scan_init_scope && is_init_))) {
((op->attr_key == tir::attr::scan_update_scope && !is_init_) ||
(op->attr_key == tir::attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) {
found_attach = true;
stmt = AttrStmtNode::make(
......@@ -187,15 +187,15 @@ class SchedulePostProc : public StmtExprMutator {
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::loop_scope ||
op->attr_key == attr::scan_init_scope) {
if (op->attr_key == tir::attr::loop_scope ||
op->attr_key == tir::attr::scan_init_scope) {
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::scan_update_scope) {
} else if (op->attr_key == tir::attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value;
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::thread_extent) {
} else if (op->attr_key == tir::attr::thread_extent) {
// delete duplicated thread extent attr
auto it = thread_extent_scope_.find(op->node.get());
if (it != thread_extent_scope_.end()) {
......
......@@ -32,25 +32,51 @@
namespace tvm {
namespace tir {
Var::Var(std::string name_hint, DataType t)
: Var(make_object<VarNode>(t, name_hint)) {}
// Construct a variable from a name hint and a runtime data type.
// No type annotation is set by this overload.
Var::Var(std::string name_hint, DataType dtype) {
  ObjectPtr<VarNode> node = make_object<VarNode>();
  node->dtype = std::move(dtype);
  node->name_hint = std::move(name_hint);
  data_ = std::move(node);
}
VarNode::VarNode(DataType t, std::string name_hint) {
this->dtype = t;
this->name_hint = std::move(name_hint);
// Construct a variable from a name hint and a type annotation.
// The runtime dtype is derived from the annotation via GetRuntimeDataType.
Var::Var(std::string name_hint, Type type_annotation) {
  ObjectPtr<VarNode> node = make_object<VarNode>();
  // Derive the dtype first: type_annotation is moved from on the next line.
  node->dtype = GetRuntimeDataType(type_annotation);
  node->type_annotation = std::move(type_annotation);
  node->name_hint = std::move(name_hint);
  data_ = std::move(node);
}
SizeVar::SizeVar(std::string name_hint, DataType t)
: SizeVar(make_object<SizeVarNode>(t, name_hint)) {}
SizeVarNode::SizeVarNode(DataType t, std::string name_hint)
: VarNode(t, std::move(name_hint)) {}
// Return a new Var whose node is a copy of this one with `suffix`
// appended to name_hint. A SizeVarNode is copied as a SizeVarNode so the
// node kind is kept.
Var Var::copy_with_suffix(const std::string& suffix) const {
  ObjectPtr<VarNode> copied;
  if (const auto* size_var = this->as<SizeVarNode>()) {
    copied = make_object<SizeVarNode>(*size_var);
  } else {
    copied = make_object<VarNode>(*get());
  }
  copied->name_hint += suffix;
  return Var(copied);
}
// Construct a SizeVar from a name hint and a runtime data type.
SizeVar::SizeVar(std::string name_hint, DataType dtype) {
  ObjectPtr<SizeVarNode> node = make_object<SizeVarNode>();
  node->dtype = std::move(dtype);
  node->name_hint = std::move(name_hint);
  data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.Var")
.set_body_typed([](std::string s, DataType t) {
return Var(s, t);
});
.set_body_typed([](std::string name_hint, runtime::TVMArgValue type) {
if (type.IsObjectRef<Type>()) {
return Var(name_hint, type.operator Type());
} else {
return Var(name_hint, type.operator DataType());
}
});
TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) {
......
......@@ -31,5 +31,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
} // namespace tir
} // namespace tvm
......@@ -33,14 +33,34 @@ namespace tvm {
using namespace tir;
/*!
 * \brief Map a Type to its runtime::DataType.
 *
 * A PrimType yields its stored dtype and a PointerType yields the handle
 * dtype. Any other type is reported through LOG(FATAL).
 *
 * \param type The type to convert.
 * \return The corresponding runtime data type.
 */
runtime::DataType GetRuntimeDataType(const Type& type) {
  if (const auto* prim = type.as<PrimTypeNode>()) {
    return prim->dtype;
  }
  if (type.as<PointerTypeNode>() != nullptr) {
    return DataType::Handle();
  }
  LOG(FATAL) << "Type " << type
             << " does not have a corresponding runtime::DataType";
  // Return statement kept after the fatal log so every path returns a value.
  return DataType::Handle();
}
/*!
 * \brief Get the Type of a primitive expression.
 *
 * A Var with a defined type annotation yields that annotation; every
 * other expression yields PrimType of the expression's dtype.
 *
 * \param expr The expression to inspect.
 * \return The type of the expression.
 */
Type GetType(const PrimExpr& expr) {
// TODO(tqchen): add recursive type inference for Call here
// once we introduced the corresponding fields to the IR.
if (auto* ptr = expr.as<tir::VarNode>()) {
// If Var has a more refined type annotation,
// return the type annotation
if (ptr->type_annotation.defined()) {
return ptr->type_annotation;
}
}
// Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype();
// These types already imply the specific type.
// NOTE(review): this branch returns the same value as the fall-through
// return below, so it currently has no distinct effect.
if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) {
return PrimType(dtype);
}
// TODO(tqchen): add recursive type inference for Var and Call here
// once we introduced the corresponding fields to the IR.
return PrimType(dtype);
}
......
......@@ -68,6 +68,32 @@ class IRSubstitue : public StmtExprMutator {
}
}
// Substitute inside a Load, then replace its buffer variable when the
// substitution map has an entry for it.
PrimExpr VisitExpr_(const LoadNode* op) final {
  // NOTE: op->buffer_var is not explicitly mutated recursively,
  // so it is remapped here by a direct lookup in smap_.
  PrimExpr mutated = StmtExprMutator::VisitExpr_(op);
  const LoadNode* load = mutated.as<LoadNode>();
  auto entry = smap_.find(load->buffer_var.get());
  if (entry == smap_.end()) {
    return mutated;
  }
  // The mapped expression is downcast to Var to serve as the buffer variable.
  return LoadNode::make(
      load->dtype, Downcast<Var>(entry->second), load->index, load->predicate);
}
// Substitute inside a Store, then replace its buffer variable when the
// substitution map has an entry for it.
Stmt VisitStmt_(const StoreNode* op) final {
  // NOTE: op->buffer_var is not explicitly mutated recursively,
  // so it is remapped here by a direct lookup in smap_.
  Stmt mutated = StmtExprMutator::VisitStmt_(op);
  const StoreNode* store = mutated.as<StoreNode>();
  auto entry = smap_.find(store->buffer_var.get());
  if (entry == smap_.end()) {
    return mutated;
  }
  // The mapped expression is downcast to Var to serve as the buffer variable.
  return StoreNode::make(
      Downcast<Var>(entry->second), store->value, store->index, store->predicate);
}
private:
const std::unordered_map<const VarNode*, PrimExpr>& smap_;
};
......
......@@ -1016,6 +1016,47 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
return LoweredFunc(n);
}
/*!
 * \brief Rewrite the handle parameters of a PrimFunc to carry pointer
 *        type annotations, after running VectorAllocRewriter on the body.
 *
 * For each handle parameter the rewriter's acc_map_ entry is consulted
 * (presumably the element types recorded for that buffer variable; confirm
 * against VectorAllocRewriter):
 *  - exactly one recorded type: the parameter is replaced by a new Var
 *    annotated as a pointer to that type;
 *  - several recorded types and no existing annotation: the parameter is
 *    replaced by a new Var annotated as a pointer to the first recorded
 *    type with lanes set to 1;
 *  - otherwise the parameter is kept unchanged.
 * Replaced parameters are substituted throughout the body.
 *
 * \param f The function to rewrite.
 * \return The rewritten function.
 */
PrimFunc PointerValueTypeRewrite(PrimFunc f) {
  auto* n = f.CopyOnWrite();
  VectorAllocRewriter rewriter;
  n->body = rewriter(n->body);

  Array<tir::Var> args;
  Map<tir::Var, PrimExpr> remap_vars;

  for (Var var : f->params) {
    // Non-handle parameters pass through untouched.
    if (!var.dtype().is_handle()) {
      args.push_back(var);
      continue;
    }

    const auto& tvec = rewriter.acc_map_[var.get()];
    if (tvec.size() == 1) {
      tir::Var new_var(var->name_hint,
                       PointerType(PrimType(tvec[0])));
      args.push_back(new_var);
      remap_vars.Set(var, new_var);
    } else if (tvec.size() != 0 && !var->type_annotation.defined()) {
      // always set data type to be non vectorized so
      // load/store can still work via scalarization
      tir::Var new_var(var->name_hint,
                       PointerType(PrimType(tvec[0].with_lanes(1))));
      args.push_back(new_var);
      remap_vars.Set(var, new_var);
    } else {
      args.push_back(var);
    }
  }

  CHECK_EQ(args.size(), n->params.size());
  n->params = args;
  n->body = Substitute(n->body, remap_vars);
  return f;
}
Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
return VectorAllocRewriter()(std::move(stmt));
......
......@@ -83,8 +83,8 @@ def test_meta_data():
text_no_meta = str(f)
assert "channels=2" in text
assert "channels=2" in text_no_meta
assert "meta[SizeVar][0]" in text
assert "meta[SizeVar][0]" in text_no_meta
assert "meta[tir.SizeVar][0]" in text
assert "meta[tir.SizeVar][0]" in text_no_meta
assert "type_key" in text
assert "type_key" not in text_no_meta
......
......@@ -108,8 +108,32 @@ def test_global_var():
assert isinstance(tvar, tvm.ir.GlobalVar)
def test_tir_var():
    """Load Var and SizeVar nodes from a hand-written JSON graph.

    The graph uses the type keys "Variable" and "SizeVar" with
    tvm_version "0.6.0"; the loaded objects are expected to be
    tvm.tir.Var and tvm.tir.SizeVar respectively.
    """
    graph = {
        "root" : 1,
        "nodes": [
            {"type_key": ""},
            {"type_key": "Variable",
             "attrs": {"dtype": "int32", "name": "x"}},
            {"type_key": "SizeVar",
             "attrs": {"dtype": "int32", "name": "y"}},
        ],
        "attrs": {"tvm_version": "0.6.0"},
        "b64ndarrays": [],
    }

    # Root index 1 selects the "Variable" node.
    loaded_var = tvm.ir.load_json(json.dumps(graph))
    assert isinstance(loaded_var, tvm.tir.Var)
    assert loaded_var.name == "x"

    # Root index 2 selects the "SizeVar" node.
    graph["root"] = 2
    loaded_size_var = tvm.ir.load_json(json.dumps(graph))
    assert isinstance(loaded_size_var, tvm.tir.SizeVar)
    assert loaded_size_var.name == "y"
if __name__ == "__main__":
test_type_var()
test_incomplete_type()
test_func_tuple_type()
test_global_var()
test_tir_var()
......@@ -265,7 +265,18 @@ def test_prim_func():
assert func.attrs is None
def test_vars():
    """Check Var construction from a dtype string and from a pointer type."""
    # Built from a dtype string: dtype is reported as given.
    int_var = tvm.tir.Var("xyz", "int8")
    assert int_var.dtype == "int8"

    # Built from a pointer type: dtype is "handle" and the annotation is kept.
    pointer_type = tvm.ir.PointerType(tvm.ir.PrimType("float"))
    handle_var = tvm.tir.Var("xyz", pointer_type)
    assert handle_var.dtype == "handle"
    assert handle_var.type_annotation == pointer_type
    assert isinstance(pointer_type.element_type, tvm.ir.PrimType)
if __name__ == "__main__":
test_vars()
test_prim_func()
test_cast()
test_attr()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment