Unverified Commit 841725cc by Tianqi Chen Committed by GitHub

[TIR][TARGET] Refactor Target codegen to use IRModule and PrimFunc. (#5107)

As part of the unified IR refactor.
This PR refactors the target codegen to use IRModule containing tir::PrimFuncs.

In order to break the refactor into several steps without breaking the codebase,
we built an conversion pass to convert Array<LoweredFunc> into IRModule.

The follow-up refactors will gradually move the passes covered by IRModule up
until we cover all the passes. Then we can remove the additional redundant
concepts such as LoweredFunc.
parent 86079479
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <string> #include <string>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include <type_traits>
namespace tvm { namespace tvm {
...@@ -308,6 +309,17 @@ class Integer : public IntImm { ...@@ -308,6 +309,17 @@ class Integer : public IntImm {
*/ */
Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*) Integer(IntImm other) : IntImm(std::move(other)) {} // NOLINT(*)
/*! /*!
* \brief Constructor from enum
* \tparam Enum The enum type.
* \param value The enum value.
*/
template<typename ENum,
typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
explicit Integer(ENum value) : Integer(static_cast<int>(value)) {
static_assert(std::is_same<int, typename std::underlying_type<ENum>::type>::value,
"declare enum to be enum int to use visitor");
}
/*!
* \brief Assign an expression to integer. * \brief Assign an expression to integer.
* \param other another expression. * \param other another expression.
*/ */
......
...@@ -213,6 +213,27 @@ constexpr const char* kCallingConv = "calling_conv"; ...@@ -213,6 +213,27 @@ constexpr const char* kCallingConv = "calling_conv";
* \sa tvm::Target * \sa tvm::Target
*/ */
constexpr const char* kTarget = "target"; constexpr const char* kTarget = "target";
/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";
} // namespace attr } // namespace attr
} // namespace tvm } // namespace tvm
#endif // TVM_IR_FUNCTION_H_ #endif // TVM_IR_FUNCTION_H_
...@@ -114,7 +114,8 @@ class PrimTypeNode : public TypeNode { ...@@ -114,7 +114,8 @@ class PrimTypeNode : public TypeNode {
TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode);
}; };
/*!
/*
* \brief Managed reference to PrimTypeNode. * \brief Managed reference to PrimTypeNode.
* \sa PrimTypeNode * \sa PrimTypeNode
*/ */
...@@ -124,11 +125,53 @@ class PrimType : public Type { ...@@ -124,11 +125,53 @@ class PrimType : public Type {
* \brief Constructor * \brief Constructor
* \param dtype The corresponding dtype. * \param dtype The corresponding dtype.
*/ */
TVM_DLL PrimType(runtime::DataType dtype); TVM_DLL explicit PrimType(runtime::DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode); TVM_DEFINE_OBJECT_REF_METHODS(PrimType, Type, PrimTypeNode);
}; };
/*!
* \brief Low-level raw pointer type.
*
* PointerType represents type hints in the TIR to be
* passed to the final code generator.
*
* PointerType should not occur in the high-level analysis.
*
* \sa PointerType
*/
class PointerTypeNode : public TypeNode {
public:
/*!
* \brief The type of the element which the pointer points to.
*/
Type element_type;
void VisitAttrs(AttrVisitor* v) {
v->Visit("element_type", &element_type);
}
static constexpr const char* _type_key = "PointerType";
TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode);
};
/*
* \brief Managed reference to PointerTypeNode.
* \sa PointerTypeNode
*/
class PointerType : public Type {
public:
/*!
* \brief Constructor
* \param element_type The type of the element which the pointer points to.
*/
TVM_DLL explicit PointerType(Type element_type);
TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
/*! \brief Possible kinds of TypeVars. */ /*! \brief Possible kinds of TypeVars. */
enum TypeKind : int { enum TypeKind : int {
kType = 0, kType = 0,
...@@ -284,6 +327,15 @@ inline Type VoidType() { ...@@ -284,6 +327,15 @@ inline Type VoidType() {
} }
/*! /*!
* \brief Check whether the tyep represents void.
* \return The check result.
*/
inline bool IsVoidType(const Type& type) {
auto* n = type.as<TupleTypeNode>();
return n && n->fields.size() == 0;
}
/*!
* \brief Potential Constraints in a function. * \brief Potential Constraints in a function.
* \sa TypeConstraint * \sa TypeConstraint
*/ */
......
...@@ -55,22 +55,27 @@ namespace tir { ...@@ -55,22 +55,27 @@ namespace tir {
*/ */
class VarNode : public PrimExprNode { class VarNode : public PrimExprNode {
public: public:
/*! \brief constructor */
VarNode() {}
VarNode(DataType dtype, std::string name_hint);
/*! /*!
* \brief The hint to the variable name. * \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address. * \note Each variable is uniquely identified by its address.
*/ */
std::string name_hint; std::string name_hint;
/*!
* \brief type annotaion of the variable.
*
* It is an optional field that provides a refined type of the variable than dtype.
*
* \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type.
*/
Type type_annotation;
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype); v->Visit("dtype", &dtype);
v->Visit("name", &name_hint); v->Visit("name", &name_hint);
v->Visit("type_annotation", &type_annotation);
} }
static constexpr const char* _type_key = "Variable"; static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode); TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
}; };
...@@ -78,20 +83,25 @@ class VarNode : public PrimExprNode { ...@@ -78,20 +83,25 @@ class VarNode : public PrimExprNode {
class Var : public PrimExpr { class Var : public PrimExpr {
public: public:
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {} explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
/*! \brief constructor /*!
* \brief Constructor
* \param name_hint variable name * \param name_hint variable name
* \param t data type * \param dtype data type
*/ */
TVM_DLL explicit Var(std::string name_hint = "v", TVM_DLL explicit Var(std::string name_hint = "v",
DataType t = DataType::Int(32)); DataType dtype = DataType::Int(32));
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
*/
TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
/*! /*!
* \brief Make a new copy of var with same type, append suffix * \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended. * \param suffix The suffix to be appended.
* \return the new Var copy * \return the new Var copy
*/ */
Var copy_with_suffix(const std::string& suffix) const { TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
return Var((*this)->name_hint + suffix, (*this)->dtype);
}
/*! /*!
* \brief Get pointer to the internal value. * \brief Get pointer to the internal value.
* \return the corresponding Variable. * \return the corresponding Variable.
...@@ -116,15 +126,7 @@ class Var : public PrimExpr { ...@@ -116,15 +126,7 @@ class Var : public PrimExpr {
*/ */
class SizeVarNode : public VarNode { class SizeVarNode : public VarNode {
public: public:
/*! \brief constructor */ static constexpr const char* _type_key = "tir.SizeVar";
SizeVarNode() {}
/*! \brief constructor
* \param dtype data type
* \param name_hint variable name
*/
SizeVarNode(DataType dtype, std::string name_hint);
static constexpr const char* _type_key = "SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
}; };
...@@ -132,12 +134,13 @@ class SizeVarNode : public VarNode { ...@@ -132,12 +134,13 @@ class SizeVarNode : public VarNode {
class SizeVar : public Var { class SizeVar : public Var {
public: public:
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {} explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
/*! \brief constructor /*!
* \brief constructor
* \param name_hint variable name * \param name_hint variable name
* \param t data type * \param t data type
*/ */
TVM_DLL explicit SizeVar(std::string name_hint = "s", TVM_DLL explicit SizeVar(std::string name_hint = "s",
DataType t = DataType::Int(32)); DataType t = DataType::Int(32));
/*! /*!
* \brief Get pointer to the internal value. * \brief Get pointer to the internal value.
* \return the corresponding Variable. * \return the corresponding Variable.
......
...@@ -171,6 +171,16 @@ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis"; ...@@ -171,6 +171,16 @@ constexpr const char* kDeviceThreadAxis = "tir.device_thread_axis";
* Type: Integer * Type: Integer
*/ */
constexpr const char* kNoAlias = "tir.noalias"; constexpr const char* kNoAlias = "tir.noalias";
/*!
* \brief Mark the function as the entry function of
* the final generated runtime module.
*
* Type: Integer
*
* \note There can only be one entry function per module.
*/
constexpr const char* kIsEntryFunc = "tir.is_entry_func";
} // namespace attr } // namespace attr
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <tvm/te/schedule.h> #include <tvm/te/schedule.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
#include <tvm/tir/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <unordered_map> #include <unordered_map>
...@@ -515,6 +516,19 @@ LoweredFunc CombineContextCall(LoweredFunc f); ...@@ -515,6 +516,19 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/ */
LoweredFunc PointerValueTypeRewrite(LoweredFunc f); LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemeneted in storage_rewrite.cc
* \param f The function to be trasnformed
* \return Transformed function.
*/
PrimFunc PointerValueTypeRewrite(PrimFunc f);
/*! /*!
* \brief Lower attached storage access information on device. * \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish. * Do this pass after all storage access analysis finish.
......
...@@ -52,11 +52,24 @@ namespace tvm { ...@@ -52,11 +52,24 @@ namespace tvm {
* This function could return a more refined type than * This function could return a more refined type than
* the runtime type provided by expr->dtype * the runtime type provided by expr->dtype
* *
* \param expr The input parameter.
* \return The result type.
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType. * \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/ */
TVM_DLL Type GetType(const PrimExpr& expr); TVM_DLL Type GetType(const PrimExpr& expr);
/*! /*!
* \brief Get the implied DataType for storing values with type during runtime.
*
* \param type The input type.
* \return The result runtime::DataType.
*
* \sa tvm/ir/type.h for discussion about the relation between Type and runtime::DataType.
*/
TVM_DLL runtime::DataType GetRuntimeDataType(const Type& type);
/*!
* Query the maximum possible value of dtype. * Query the maximum possible value of dtype.
* \param dtype The data type. * \param dtype The data type.
* \return the maximum possible value in this format. * \return the maximum possible value in this format.
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# pylint: disable=unused-import # pylint: disable=unused-import
"""Common data structures across all IR variants.""" """Common data structures across all IR variants."""
from .base import SourceName, Span, Node, EnvFunc, load_json, save_json from .base import SourceName, Span, Node, EnvFunc, load_json, save_json
from .type import Type, TypeKind, PrimType, TypeVar, GlobalTypeVar, TupleType from .type import Type, TypeKind, PrimType, PointerType, TypeVar, GlobalTypeVar, TupleType
from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation from .type_relation import TypeCall, TypeRelation
......
...@@ -72,7 +72,15 @@ def create_updater_06_to_07(): ...@@ -72,7 +72,15 @@ def create_updater_06_to_07():
return item return item
return _convert return _convert
def _update_tir_var(new_name):
def _convert(item, _):
item["type_key"] = new_name
item["attrs"]["type_annotation"] = "0"
return item
return _convert
node_map = { node_map = {
# Base IR
"relay.TypeVar": _ftype_var, "relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"), "relay.Type": _rename("Type"),
...@@ -91,6 +99,9 @@ def create_updater_06_to_07(): ...@@ -91,6 +99,9 @@ def create_updater_06_to_07():
"relay.PassContext": _rename("transform.PassContext"), "relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"), "relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"), "relay.Sequantial": _rename("transform.Sequantial"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
} }
return create_updater(node_map, "0.6", "0.7") return create_updater(node_map, "0.6", "0.7")
......
...@@ -46,6 +46,7 @@ class TypeKind(IntEnum): ...@@ -46,6 +46,7 @@ class TypeKind(IntEnum):
TypeData = 6 TypeData = 6
@tvm._ffi.register_object("PrimType")
class PrimType(Type): class PrimType(Type):
"""Primitive data type in the low level IR """Primitive data type in the low level IR
...@@ -59,6 +60,20 @@ class PrimType(Type): ...@@ -59,6 +60,20 @@ class PrimType(Type):
_ffi_api.PrimType, dtype) _ffi_api.PrimType, dtype)
@tvm._ffi.register_object("PointerType")
class PointerType(Type):
"""PointerType used in the low-level TIR.
Parameters
----------
element_type : tvm.ir.Type
The type of pointer's element.
"""
def __init__(self, element_type):
self.__init_handle_by_constructor__(
_ffi_api.PointerType, element_type)
@tvm._ffi.register_object("TypeVar") @tvm._ffi.register_object("TypeVar")
class TypeVar(Type): class TypeVar(Type):
"""Type parameter in functions. """Type parameter in functions.
......
...@@ -288,7 +288,7 @@ class CmpExpr(PrimExprWithOp): ...@@ -288,7 +288,7 @@ class CmpExpr(PrimExprWithOp):
class LogicalExpr(PrimExprWithOp): class LogicalExpr(PrimExprWithOp):
pass pass
@tvm._ffi.register_object("Variable") @tvm._ffi.register_object("tir.Var")
class Var(PrimExprWithOp): class Var(PrimExprWithOp):
"""Symbolic variable. """Symbolic variable.
...@@ -297,7 +297,7 @@ class Var(PrimExprWithOp): ...@@ -297,7 +297,7 @@ class Var(PrimExprWithOp):
name : str name : str
The name The name
dtype : str dtype : Union[str, tvm.irType]
The data type The data type
""" """
def __init__(self, name, dtype): def __init__(self, name, dtype):
...@@ -305,7 +305,7 @@ class Var(PrimExprWithOp): ...@@ -305,7 +305,7 @@ class Var(PrimExprWithOp):
_ffi_api.Var, name, dtype) _ffi_api.Var, name, dtype)
@tvm._ffi.register_object @tvm._ffi.register_object("tir.SizeVar")
class SizeVar(Var): class SizeVar(Var):
"""Symbolic variable to represent a tensor index size """Symbolic variable to represent a tensor index size
which is greater or equal to zero. which is greater or equal to zero.
......
...@@ -68,7 +68,7 @@ class FuncTouchedDomain final : public StmtExprVisitor { ...@@ -68,7 +68,7 @@ class FuncTouchedDomain final : public StmtExprVisitor {
/* TODO: Thread extent unitest not generated.*/ /* TODO: Thread extent unitest not generated.*/
void VisitStmt_(const AttrStmtNode* op) final { void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
const IterVarNode* thread_axis = op->node.as<IterVarNode>(); const IterVarNode* thread_axis = op->node.as<IterVarNode>();
CHECK(thread_axis); CHECK(thread_axis);
const VarNode* var = thread_axis->var.get(); const VarNode* var = thread_axis->var.get();
......
...@@ -92,8 +92,8 @@ VisitStmt_(const IfThenElseNode* op) { ...@@ -92,8 +92,8 @@ VisitStmt_(const IfThenElseNode* op) {
Stmt IRMutatorWithAnalyzer:: Stmt IRMutatorWithAnalyzer::
VisitStmt_(const AttrStmtNode* op) { VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent || if (op->attr_key == tir::attr::thread_extent ||
op->attr_key == attr::virtual_thread) { op->attr_key == tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U); CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var, analyzer_->Bind(iv->var,
......
...@@ -40,7 +40,7 @@ using runtime::PackedFunc; ...@@ -40,7 +40,7 @@ using runtime::PackedFunc;
using tir::LoweredFunc; using tir::LoweredFunc;
bool LLVMEnabled() { bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm"); const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
return pf != nullptr; return pf != nullptr;
} }
......
...@@ -45,6 +45,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -45,6 +45,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
PointerType::PointerType(Type element_type) {
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
n->element_type = std::move(element_type);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PointerTypeNode);
TVM_REGISTER_GLOBAL("ir.PointerType")
.set_body_typed([](Type element_type) {
return PointerType(element_type);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PointerTypeNode*>(ref.get());
p->Print(node->element_type);
p->stream << '*';
});
TypeVar::TypeVar(std::string name, TypeKind kind) { TypeVar::TypeVar(std::string name, TypeKind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>(); ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->name_hint = std::move(name); n->name_hint = std::move(name);
......
...@@ -139,7 +139,7 @@ bool RuntimeEnabled(const std::string& target) { ...@@ -139,7 +139,7 @@ bool RuntimeEnabled(const std::string& target) {
} else if (target == "vulkan") { } else if (target == "vulkan") {
f_name = "device_api.vulkan"; f_name = "device_api.vulkan";
} else if (target == "stackvm") { } else if (target == "stackvm") {
f_name = "codegen.build_stackvm"; f_name = "target.build.stackvm";
} else if (target == "rpc") { } else if (target == "rpc") {
f_name = "device_api.rpc"; f_name = "device_api.rpc";
} else if (target == "micro_dev") { } else if (target == "micro_dev") {
......
...@@ -26,6 +26,9 @@ ...@@ -26,6 +26,9 @@
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/ir/module.h>
#include <tvm/tir/function.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/lowered_func.h> #include <tvm/tir/lowered_func.h>
...@@ -51,6 +54,31 @@ ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) { ...@@ -51,6 +54,31 @@ ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) {
} }
return fmap; return fmap;
} }
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule& mod) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<tir::PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<tir::PrimFunc>(kv.second);
runtime::FunctionInfo info;
for (size_t i = 0; i < f->params.size(); ++i) {
info.arg_types.push_back(f->params[i].dtype());
}
auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis);
if (thread_axis.defined()) {
for (size_t i = 0; i < thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(thread_axis[i]->thread_tag);
}
}
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
fmap[static_cast<std::string>(global_symbol)] = info;
}
return fmap;
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_TARGET_BUILD_COMMON_H_ #endif // TVM_TARGET_BUILD_COMMON_H_
...@@ -23,7 +23,12 @@ ...@@ -23,7 +23,12 @@
*/ */
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/target/target.h> #include <tvm/target/target.h>
#include <tvm/ir/module.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/function.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
...@@ -37,6 +42,63 @@ ...@@ -37,6 +42,63 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
// The new build function.
// adapt the old function to the new one
runtime::Module BuildForIRModule(const IRModule& module,
const Target& target) {
std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "target.build." << target << " is not enabled";
return (*bf)(module, target->str());
}
// convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations.
Array<tir::Var> args;
Map<tir::Var, PrimExpr> remap_vars;
for (auto var : from->args) {
if (from->handle_data_type.count(var)) {
tir::Var new_var(var->name_hint,
PointerType(PrimType(var->dtype)));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
args.push_back(var);
}
}
tir::PrimFunc func(args, Substitute(from->body, remap_vars));
func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name));
func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis);
if (from->func_type == tir::LoweredFuncType::kDeviceFunc) {
func = WithAttr(std::move(func),
attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch));
}
if (from->is_restricted) {
func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1));
}
return func;
}
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
Map<GlobalVar, BaseFunc> functions;
for (size_t i = 0; i < funcs.size(); ++i) {
auto f = funcs[i];
tir::PrimFunc pf = ToPrimFunc(f);
if (i == 0) {
pf = WithAttr(std::move(pf), tir::attr::kIsEntryFunc, Integer(1));
}
functions.Set(GlobalVar(f->name), pf);
}
return IRModule(functions);
}
runtime::Module Build(const Array<tir::LoweredFunc>& funcs, runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target) { const std::string& target) {
std::string mode = target; std::string mode = target;
...@@ -51,15 +113,10 @@ runtime::Module Build(const Array<tir::LoweredFunc>& funcs, ...@@ -51,15 +113,10 @@ runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
transformed_funcs.push_back(func); transformed_funcs.push_back(func);
} }
} }
std::string build_f_name = "codegen.build_" + mode;
// the build function. return BuildForIRModule(
const PackedFunc* bf = runtime::Registry::Get(build_f_name); transformed_funcs.size() != 0 ? ToIRModule(transformed_funcs) : ToIRModule(funcs),
CHECK(bf != nullptr) Target::Create(target));
<< "Target " << target << " is not enabled";
runtime::Module m = transformed_funcs.empty() ?
(*bf)(funcs, target) :
(*bf)(transformed_funcs, target);
return m;
} }
/*! \brief Helper class to serialize module */ /*! \brief Helper class to serialize module */
......
...@@ -59,7 +59,7 @@ static inline int DetectROCMmaxThreadsPerBlock() { ...@@ -59,7 +59,7 @@ static inline int DetectROCMmaxThreadsPerBlock() {
// AMDGPU code generator. // AMDGPU code generator.
class CodeGenAMDGPU : public CodeGenLLVM { class CodeGenAMDGPU : public CodeGenLLVM {
public: public:
void AddFunction(const LoweredFunc& f) final { void AddFunction(const PrimFunc& f) final {
// add function as void return value // add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true); CodeGenLLVM::AddFunctionInternal(f, true);
function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
...@@ -91,7 +91,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { ...@@ -91,7 +91,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
// TODO(tqchen): for higher version of LLVM, local address space can be set. // TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca( return builder_->CreateAlloca(
LLVMType(op->dtype), ConstInt32(constant_size)); DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
}); });
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) { if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100 #if TVM_LLVM_VERSION >= 100
...@@ -106,7 +106,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { ...@@ -106,7 +106,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
<< "Can only allocate shared or local memory inside kernel"; << "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3 // Shared memory: address space == 3
const unsigned shared_address_space = 3; const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size); llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3 // Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable( llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
...@@ -120,7 +121,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { ...@@ -120,7 +121,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
} }
} }
buf = builder_->CreatePointerCast( buf = builder_->CreatePointerCast(
buf, LLVMType(op->dtype)->getPointerTo( buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace())); buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get())); CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf; var_map_[op->buffer_var.get()] = buf;
...@@ -170,7 +171,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { ...@@ -170,7 +171,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
// Additional optimization hook to tweak the builder. // Additional optimization hook to tweak the builder.
} }
unsigned GetGlobalAddressSpace() { unsigned GetGlobalAddressSpace() const final {
return 1; return 1;
} }
...@@ -205,7 +206,7 @@ inline int DetectROCMComputeVersion(const std::string& target) { ...@@ -205,7 +206,7 @@ inline int DetectROCMComputeVersion(const std::string& target) {
return 900; return 900;
} }
runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
#if TVM_LLVM_VERSION < 90 #if TVM_LLVM_VERSION < 90
LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
// Lower versions will crash when loading the bitcode, see // Lower versions will crash when loading the bitcode, see
...@@ -222,8 +223,13 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { ...@@ -222,8 +223,13 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str()); std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU()); std::unique_ptr<CodeGenAMDGPU> cg(new CodeGenAMDGPU());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext()); std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false);
for (LoweredFunc f : funcs) { cg->Init("TVMAMDGPUModule", tm.get(), ctx.get(), false, false);
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f); cg->AddFunction(f);
} }
...@@ -306,10 +312,10 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { ...@@ -306,10 +312,10 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
std::string hsaco = (*f)(arr); std::string hsaco = (*f)(arr);
std::string ll(data_ll.begin(), data_ll.end()); std::string ll(data_ll.begin(), data_ll.end());
return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly); return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(mod), ll, assembly);
} }
TVM_REGISTER_GLOBAL("codegen.build_rocm") TVM_REGISTER_GLOBAL("target.build.rocm")
.set_body_typed(BuildAMDGPU); .set_body_typed(BuildAMDGPU);
} // namespace codegen } // namespace codegen
......
...@@ -122,11 +122,15 @@ void CodeGenCPU::Init(const std::string& module_name, ...@@ -122,11 +122,15 @@ void CodeGenCPU::Init(const std::string& module_name,
this->InitGlobalContext(dynamic_lookup); this->InitGlobalContext(dynamic_lookup);
} }
void CodeGenCPU::AddFunction(const LoweredFunc& f) { void CodeGenCPU::AddFunction(const PrimFunc& f) {
CodeGenLLVM::AddFunction(f); CodeGenLLVM::AddFunction(f);
if (f_tvm_register_system_symbol_ != nullptr) { if (f_tvm_register_system_symbol_ != nullptr) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
export_system_symbols_.emplace_back( export_system_symbols_.emplace_back(
std::make_pair(f->name, builder_->CreatePointerCast(function_, t_void_p_))); std::make_pair(global_symbol.operator std::string(),
builder_->CreatePointerCast(function_, t_void_p_)));
} }
AddDebugInformation(function_); AddDebugInformation(function_);
} }
...@@ -328,7 +332,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { ...@@ -328,7 +332,7 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) {
arg_types.push_back(v->getType()); arg_types.push_back(v->getType());
} }
llvm::FunctionType* ftype = llvm::FunctionType::get( llvm::FunctionType* ftype = llvm::FunctionType::get(
LLVMType(op->dtype), arg_types, false); GetLLVMType(GetRef<PrimExpr>(op)), arg_types, false);
// Check if it is available in global function table as injected function. // Check if it is available in global function table as injected function.
auto it = gv_func_map_.find(op->name); auto it = gv_func_map_.find(op->name);
if (it != gv_func_map_.end()) { if (it != gv_func_map_.end()) {
...@@ -693,8 +697,8 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue, ...@@ -693,8 +697,8 @@ CodeGenCPU::MakeCallPacked(const Array<PrimExpr> &args, llvm::Value **rvalue,
ret_value, *ret_tcode})); ret_value, *ret_tcode}));
DataType r_api_type = tir::APIType(r_type); DataType r_api_type = tir::APIType(r_type);
*rvalue = builder_->CreateAlignedLoad( *rvalue = builder_->CreateAlignedLoad(
builder_->CreatePointerCast(ret_value, builder_->CreatePointerCast(
LLVMType(r_api_type)->getPointerTo()), ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()),
8); 8);
*rvalue = CreateCast(r_api_type, r_type, *rvalue); *rvalue = CreateCast(r_api_type, r_type, *rvalue);
return end_block; return end_block;
...@@ -873,7 +877,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { ...@@ -873,7 +877,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body); this->CreateStaticInit(op->value.as<StringImmNode>()->value, op->body);
} else if (op->attr_key == tir::attr::compute_scope) { } else if (op->attr_key == tir::attr::compute_scope) {
this->CreateComputeScope(op); this->CreateComputeScope(op);
} else if (attr::IsPragmaKey(op->attr_key)) { } else if (tir::attr::IsPragmaKey(op->attr_key)) {
if (op->attr_key == "pragma_parallel_stride_pattern") { if (op->attr_key == "pragma_parallel_stride_pattern") {
CHECK(parallel_env_.penv != nullptr) CHECK(parallel_env_.penv != nullptr)
<< "Pragma parallel_stride_pattern only valid in parallel launch"; << "Pragma parallel_stride_pattern only valid in parallel launch";
......
...@@ -42,7 +42,7 @@ class CodeGenCPU : public CodeGenLLVM { ...@@ -42,7 +42,7 @@ class CodeGenCPU : public CodeGenLLVM {
llvm::LLVMContext* ctx, llvm::LLVMContext* ctx,
bool system_lib, bool system_lib,
bool dynamic_lookup) override; bool dynamic_lookup) override;
void AddFunction(const LoweredFunc& f) override; void AddFunction(const PrimFunc& f) override;
void AddMainFunction(const std::string& entry_func_name) override; void AddMainFunction(const std::string& entry_func_name) override;
std::unique_ptr<llvm::Module> Finish() override; std::unique_ptr<llvm::Module> Finish() override;
void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override;
......
...@@ -25,12 +25,17 @@ ...@@ -25,12 +25,17 @@
#define TVM_TARGET_LLVM_CODEGEN_LLVM_H_ #define TVM_TARGET_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/ir/module.h>
#include <tvm/runtime/container.h>
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -78,7 +83,7 @@ class CodeGenLLVM : ...@@ -78,7 +83,7 @@ class CodeGenLLVM :
* \brief Compile and add function f to the current module. * \brief Compile and add function f to the current module.
* \param f The function to be added. * \param f The function to be added.
*/ */
virtual void AddFunction(const LoweredFunc& f); virtual void AddFunction(const PrimFunc& f);
/*! /*!
* \brief Add main function as the entry name * \brief Add main function as the entry name
* \param entry_func_name The name of entry function to be added. * \param entry_func_name The name of entry function to be added.
...@@ -167,7 +172,7 @@ class CodeGenLLVM : ...@@ -167,7 +172,7 @@ class CodeGenLLVM :
* \return The result. * \return The result.
*/ */
template<typename F> template<typename F>
inline llvm::AllocaInst* WithFunctionEntry(F falloca) { llvm::AllocaInst* WithFunctionEntry(F falloca) {
llvm::BasicBlock* current = builder_->GetInsertBlock(); llvm::BasicBlock* current = builder_->GetInsertBlock();
llvm::BasicBlock* entry = &(function_->getEntryBlock()); llvm::BasicBlock* entry = &(function_->getEntryBlock());
builder_->SetInsertPoint(entry, entry->begin()); builder_->SetInsertPoint(entry, entry->begin());
...@@ -198,18 +203,35 @@ class CodeGenLLVM : ...@@ -198,18 +203,35 @@ class CodeGenLLVM :
// Get the maximim storage align bits of buffer pointer given storage scope. // Get the maximim storage align bits of buffer pointer given storage scope.
virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const; virtual int NativeVectorBits(const runtime::StorageScope& storage_scope) const;
// Get correct address space depending on the backend // Get correct address space depending on the backend
virtual unsigned GetGlobalAddressSpace(); virtual unsigned GetGlobalAddressSpace() const;
void AddFunctionInternal(const PrimFunc& f, bool ret_void);
void AddFunctionInternal(const LoweredFunc& f, bool ret_void);
// Create extern call // Create extern call
llvm::CallInst* CreateCallExtern(llvm::Type* ret, llvm::CallInst* CreateCallExtern(llvm::Type* ret,
const std::string& name, const std::string& name,
const std::vector<llvm::Value*>& value); const std::vector<llvm::Value*>& value);
/*! /*!
* \param t The original type. * \brief Get the LLVM Type for a given runtime type.
* \return LLVM type of t * \param dtype The runtime dtype.
*
* \note Only use this function for dealing with PrimTypes.
* For Call and Var that could have more refined types,
* use GetLLVMType instead.
*
* \return LLVM type of dtype
*/
llvm::Type* DTypeToLLVMType(const DataType& dtype) const;
/*!
* \brief Get the LLVM Type for a given type.
* \param dtype The runtime dtype.
* \param type The corresponding TVM Type.
*/
llvm::Type* GetLLVMType(const Type& type) const;
/*!
* \brief Get the LLVM Type for a given type.
* \param dtype The runtime dtype.
* \param type The corresponding TVM Type.
*/ */
llvm::Type* LLVMType(const DataType& t) const; llvm::Type* GetLLVMType(const PrimExpr& expr) const;
// initialize the function state. // initialize the function state.
void InitFuncState(); void InitFuncState();
// Get alignment given index. // Get alignment given index.
......
...@@ -34,7 +34,7 @@ namespace codegen { ...@@ -34,7 +34,7 @@ namespace codegen {
// NVPTX code generator. // NVPTX code generator.
class CodeGenNVPTX : public CodeGenLLVM { class CodeGenNVPTX : public CodeGenLLVM {
public: public:
void AddFunction(const LoweredFunc& f) final { void AddFunction(const PrimFunc& f) final {
// add function as void return value // add function as void return value
CodeGenLLVM::AddFunctionInternal(f, true); CodeGenLLVM::AddFunctionInternal(f, true);
// annotate as kernel function // annotate as kernel function
...@@ -68,7 +68,7 @@ class CodeGenNVPTX : public CodeGenLLVM { ...@@ -68,7 +68,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
// TODO(tqchen): for higher version of LLVM, local address space can be set. // TODO(tqchen): for higher version of LLVM, local address space can be set.
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
return builder_->CreateAlloca( return builder_->CreateAlloca(
LLVMType(op->dtype), ConstInt32(constant_size)); DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
}); });
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) { if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100 #if TVM_LLVM_VERSION >= 100
...@@ -83,7 +83,8 @@ class CodeGenNVPTX : public CodeGenLLVM { ...@@ -83,7 +83,8 @@ class CodeGenNVPTX : public CodeGenLLVM {
<< "Can only allocate shared or local memory inside kernel"; << "Can only allocate shared or local memory inside kernel";
// Shared memory: address space == 3 // Shared memory: address space == 3
const unsigned shared_address_space = 3; const unsigned shared_address_space = 3;
llvm::Type* type = llvm::ArrayType::get(LLVMType(op->dtype), constant_size); llvm::Type* type = llvm::ArrayType::get(
DTypeToLLVMType(op->dtype), constant_size);
// Allocate shared memory in global, address_space = 3 // Allocate shared memory in global, address_space = 3
llvm::GlobalVariable *global = new llvm::GlobalVariable( llvm::GlobalVariable *global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared",
...@@ -97,7 +98,7 @@ class CodeGenNVPTX : public CodeGenLLVM { ...@@ -97,7 +98,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
} }
} }
buf = builder_->CreatePointerCast( buf = builder_->CreatePointerCast(
buf, LLVMType(op->dtype)->getPointerTo( buf, DTypeToLLVMType(op->dtype)->getPointerTo(
buf->getType()->getPointerAddressSpace())); buf->getType()->getPointerAddressSpace()));
CHECK(!var_map_.count(op->buffer_var.get())); CHECK(!var_map_.count(op->buffer_var.get()));
var_map_[op->buffer_var.get()] = buf; var_map_[op->buffer_var.get()] = buf;
...@@ -190,7 +191,7 @@ inline int DetectCUDAComputeVersion() { ...@@ -190,7 +191,7 @@ inline int DetectCUDAComputeVersion() {
} }
} }
runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) { runtime::Module BuildNVPTX(IRModule mod, std::string target) {
InitializeLLVM(); InitializeLLVM();
CHECK(target.length() >= 5 && CHECK(target.length() >= 5 &&
target.substr(0, 5) == "nvptx"); target.substr(0, 5) == "nvptx");
...@@ -202,8 +203,13 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) { ...@@ -202,8 +203,13 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str()); std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX()); std::unique_ptr<CodeGenNVPTX> cg(new CodeGenNVPTX());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext()); std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init(funcs[0]->name, tm.get(), ctx.get(), false, false);
for (LoweredFunc f : funcs) { cg->Init("TVMPTXModule", tm.get(), ctx.get(), false, false);
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
cg->AddFunction(f); cg->AddFunction(f);
} }
...@@ -249,10 +255,10 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) { ...@@ -249,10 +255,10 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
#endif #endif
pass.run(*module); pass.run(*module);
std::string ptx(data_ptx.begin(), data_ptx.end()); std::string ptx(data_ptx.begin(), data_ptx.end());
return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(funcs), ll); return CUDAModuleCreate(ptx, "ptx", ExtractFuncInfo(mod), ll);
} }
TVM_REGISTER_GLOBAL("codegen.build_nvptx") TVM_REGISTER_GLOBAL("target.build.nvptx")
.set_body_typed(BuildNVPTX); .set_body_typed(BuildNVPTX);
} // namespace codegen } // namespace codegen
......
...@@ -88,7 +88,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ...@@ -88,7 +88,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
if (from.lanes() >= 16 && has_avx512) { if (from.lanes() >= 16 && has_avx512) {
return CallVectorIntrin( return CallVectorIntrin(
::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, ::llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16,
LLVMType(DataType::Float(32, from.lanes())), DTypeToLLVMType(DataType::Float(32, from.lanes())),
{ {
MakeValue(tir::CallNode::make( MakeValue(tir::CallNode::make(
DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
...@@ -103,7 +103,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { ...@@ -103,7 +103,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
if (from.lanes() >= 8 && has_f16c) { if (from.lanes() >= 8 && has_f16c) {
return CallVectorIntrin( return CallVectorIntrin(
::llvm::Intrinsic::x86_vcvtph2ps_256, 8, LLVMType(DataType::Float(32, from.lanes())), ::llvm::Intrinsic::x86_vcvtph2ps_256, 8,
DTypeToLLVMType(DataType::Float(32, from.lanes())),
{MakeValue(tir::CallNode::make( {MakeValue(tir::CallNode::make(
DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value}, DataType::Int(16, from.lanes()), tir::CallNode::reinterpret, {op->value},
tir::CallNode::PureIntrinsic))}); tir::CallNode::PureIntrinsic))});
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <mutex> #include <mutex>
#include "llvm_common.h" #include "llvm_common.h"
...@@ -192,21 +193,39 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -192,21 +193,39 @@ class LLVMModuleNode final : public runtime::ModuleNode {
return ""; return "";
} }
void Init(const Array<LoweredFunc>& funcs, std::string target) { void Init(const IRModule& mod, std::string target) {
InitializeLLVM(); InitializeLLVM();
tm_ = GetLLVMTargetMachine(target); tm_ = GetLLVMTargetMachine(target);
bool system_lib = (target.find("-system-lib") != std::string::npos); bool system_lib = (target.find("-system-lib") != std::string::npos);
CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = std::make_shared<llvm::LLVMContext>();
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get()); std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
entry_func_ = funcs[0]->name;
cg->Init(funcs[0]->name, tm_.get(), ctx_.get(), system_lib, system_lib); std::vector<PrimFunc> funcs;
for (LoweredFunc f : funcs) { for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
auto f = Downcast<PrimFunc>(kv.second);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
entry_func_ = global_symbol;
}
funcs.push_back(f);
}
CHECK_NE(funcs.size(), 0U);
// TODO(tqchen): remove the entry function behavior as it does not
// makes sense when we start to use multiple modules.
cg->Init("TVMMod", tm_.get(), ctx_.get(), system_lib, system_lib);
for (const auto& f : funcs) {
cg->AddFunction(f); cg->AddFunction(f);
} }
cg->AddMainFunction(funcs[0]->name);
module_ = cg->Finish();
if (entry_func_.length() != 0) {
cg->AddMainFunction(entry_func_);
}
module_ = cg->Finish();
module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, target)); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", llvm::MDString::get(*ctx_, target));
module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", module_->addModuleFlag(llvm::Module::Override, "Debug Info Version",
llvm::DEBUG_METADATA_VERSION); llvm::DEBUG_METADATA_VERSION);
...@@ -349,12 +368,14 @@ unsigned LookupLLVMIntrinsic(const std::string& name) { ...@@ -349,12 +368,14 @@ unsigned LookupLLVMIntrinsic(const std::string& name) {
return llvm::Function::lookupIntrinsicID(name); return llvm::Function::lookupIntrinsicID(name);
} }
TVM_REGISTER_GLOBAL("codegen.build_llvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { TVM_REGISTER_GLOBAL("target.build.llvm")
auto n = make_object<LLVMModuleNode>(); .set_body_typed([](IRModule mod, std::string target) {
n->Init(args[0].operator Array<LoweredFunc>(), args[1].operator std::string()); auto n = make_object<LLVMModuleNode>();
*rv = runtime::Module(n); n->Init(mod, target);
}); return runtime::Module(n);
});
TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
......
...@@ -127,15 +127,23 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { ...@@ -127,15 +127,23 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) {
return ptx; return ptx;
} }
runtime::Module BuildCUDA(Array<LoweredFunc> funcs) { runtime::Module BuildCUDA(IRModule mod) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenCUDA cg; CodeGenCUDA cg;
cg.Init(output_ssa); cg.Init(output_ssa);
for (LoweredFunc f : funcs) { for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenCUDA: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
...@@ -151,10 +159,10 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) { ...@@ -151,10 +159,10 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
} else { } else {
ptx = NVRTCCompile(code, cg.need_include_path()); ptx = NVRTCCompile(code, cg.need_include_path());
} }
return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(funcs), code); return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
} }
TVM_REGISTER_GLOBAL("codegen.build_cuda") TVM_REGISTER_GLOBAL("target.build.cuda")
.set_body_typed(BuildCUDA); .set_body_typed(BuildCUDA);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -31,16 +31,26 @@ ...@@ -31,16 +31,26 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str, runtime::Module BuildAOCL(IRModule mod,
std::string target_str,
bool emulation) { bool emulation) {
// Get code. // Get code.
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenOpenCL cg; CodeGenOpenCL cg;
cg.Init(output_ssa); cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodegenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) { if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string(); code = (*f)(code).operator std::string();
...@@ -68,15 +78,15 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str, ...@@ -68,15 +78,15 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
std::string aocxbin; std::string aocxbin;
runtime::LoadBinaryFromFile("aocl.aocx", &aocxbin); runtime::LoadBinaryFromFile("aocl.aocx", &aocxbin);
return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(funcs), code); return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code);
} }
TVM_REGISTER_GLOBAL("codegen.build_aocl") TVM_REGISTER_GLOBAL("target.build.aocl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildAOCL(args[0], args[1], false); *rv = BuildAOCL(args[0], args[1], false);
}); });
TVM_REGISTER_GLOBAL("codegen.build_aocl_sw_emu") TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildAOCL(args[0], args[1], true); *rv = BuildAOCL(args[0], args[1], true);
}); });
......
...@@ -35,7 +35,7 @@ void CodeGenC::Init(bool output_ssa) { ...@@ -35,7 +35,7 @@ void CodeGenC::Init(bool output_ssa) {
print_ssa_form_ = output_ssa; print_ssa_form_ = output_ssa;
} }
void CodeGenC::InitFuncState(LoweredFunc f) { void CodeGenC::InitFuncState(const PrimFunc& f) {
alloc_storage_scope_.clear(); alloc_storage_scope_.clear();
handle_data_type_.clear(); handle_data_type_.clear();
CodeGenSourceBase::ClearFuncState(); CodeGenSourceBase::ClearFuncState();
...@@ -72,39 +72,46 @@ void CodeGenC::ReserveKeywordsAsUnique() { ...@@ -72,39 +72,46 @@ void CodeGenC::ReserveKeywordsAsUnique() {
GetUniqueName("return"); GetUniqueName("return");
} }
void CodeGenC::AddFunction(LoweredFunc f) { void CodeGenC::AddFunction(const PrimFunc& f) {
// clear previous generated state. // clear previous generated state.
this->InitFuncState(f); this->InitFuncState(f);
// reserve keywords // reserve keywords
ReserveKeywordsAsUnique(); ReserveKeywordsAsUnique();
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.dtype());
}
this->stream << "void " << f->name << "("; auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
for (size_t i = 0; i < f->args.size(); ++i) { CHECK(global_symbol.defined())
Var v = f->args[i]; << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
this->PrintFuncPrefix();
this->stream << " " << static_cast<std::string>(global_symbol) << "(";
for (size_t i = 0; i < f->params.size(); ++i) {
tir::Var v = f->params[i];
std::string vid = AllocVarID(v.get()); std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", "; if (i != 0) stream << ", ";
if (v.dtype().is_handle()) { if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get()); auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream); PrintStorageScope(it->second, stream);
stream << ' '; stream << ' ';
}
if (handle_data_type_.count(v.get())) { PrintType(GetType(v), stream);
PrintType(handle_data_type_.at(v.get()), stream); // Register handle data type
} else { // TODO(tvm-team): consider simply keep type info in the
stream << "void"; // type annotation(via a normalizing rewriting).
if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
} }
stream << "*";
if (f->is_restricted && restrict_keyword_.length() != 0) { if (no_alias && restrict_keyword_.length() != 0) {
stream << ' ' << restrict_keyword_; stream << ' ' << restrict_keyword_;
} }
} else { } else {
PrintType(v.dtype(), stream); PrintType(GetType(v), stream);
} }
stream << ' ' << vid; stream << ' ' << vid;
} }
...@@ -112,11 +119,19 @@ void CodeGenC::AddFunction(LoweredFunc f) { ...@@ -112,11 +119,19 @@ void CodeGenC::AddFunction(LoweredFunc f) {
this->PreFunctionBody(f); this->PreFunctionBody(f);
int func_scope = this->BeginScope(); int func_scope = this->BeginScope();
this->PrintStmt(f->body); this->PrintStmt(f->body);
this->PrintFinalReturn();
this->EndScope(func_scope); this->EndScope(func_scope);
this->PrintIndent(); this->PrintIndent();
this->stream << "}\n\n"; this->stream << "}\n\n";
} }
void CodeGenC::PrintFuncPrefix() {
stream << "void";
}
void CodeGenC::PrintFinalReturn() {
}
std::string CodeGenC::Finish() { std::string CodeGenC::Finish() {
return decl_stream.str() + stream.str(); return decl_stream.str() + stream.str();
} }
...@@ -275,7 +290,6 @@ std::string CodeGenC::GetStructRef( ...@@ -275,7 +290,6 @@ std::string CodeGenC::GetStructRef(
} }
} }
bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const { bool CodeGenC::HandleTypeMatch(const VarNode* buf_var, DataType t) const {
auto it = handle_data_type_.find(buf_var); auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) return false; if (it == handle_data_type_.end()) return false;
...@@ -370,6 +384,20 @@ void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*) ...@@ -370,6 +384,20 @@ void CodeGenC::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
} }
void CodeGenC::PrintType(const Type& type, std::ostream& os) { // NOLINT(*)
if (auto* ptr = type.as<PrimTypeNode>()) {
return PrintType(ptr->dtype, os);
} else if (auto* ptr = type.as<PointerTypeNode>()) {
PrintType(ptr->element_type, os);
os << '*';
} else if (IsVoidType(type)) {
os << "void";
} else {
LOG(FATAL) << "Type " << type << " does not have a corresponding C Type";
}
}
inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) inline void PrintConst(const IntImmNode* op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
if (op->dtype == DataType::Int(32)) { if (op->dtype == DataType::Int(32)) {
std::ostringstream temp; std::ostringstream temp;
......
...@@ -26,9 +26,11 @@ ...@@ -26,9 +26,11 @@
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <tvm/runtime/container.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
...@@ -62,8 +64,9 @@ class CodeGenC : ...@@ -62,8 +64,9 @@ class CodeGenC :
/*! /*!
* \brief Add the function to the generated module. * \brief Add the function to the generated module.
* \param f The function to be compiled. * \param f The function to be compiled.
* \param whether to append return 0 in the end.
*/ */
void AddFunction(LoweredFunc f); void AddFunction(const PrimFunc& f);
/*! /*!
* \brief Finalize the compilation and return the code. * \brief Finalize the compilation and return the code.
* \return The code. * \return The code.
...@@ -93,15 +96,25 @@ class CodeGenC : ...@@ -93,15 +96,25 @@ class CodeGenC :
} }
// The following parts are overloadable print operations. // The following parts are overloadable print operations.
/*! /*!
* \brief Print the function header before the argument list
*
* Example: stream << "void";
*/
virtual void PrintFuncPrefix(); // NOLINT(*)
/*!
* \brief Print the final return at the end the function.
*/
virtual void PrintFinalReturn(); // NOLINT(*)
/*!
* \brief Insert statement before function body. * \brief Insert statement before function body.
* \param f The function to be compiled. * \param f The function to be compiled.
*/ */
virtual void PreFunctionBody(LoweredFunc f) {} virtual void PreFunctionBody(const PrimFunc& f) {}
/*! /*!
* \brief Initialize codegen state for generating f. * \brief Initialize codegen state for generating f.
* \param f The function to be compiled. * \param f The function to be compiled.
*/ */
virtual void InitFuncState(LoweredFunc f); virtual void InitFuncState(const PrimFunc& f);
// expression // expression
void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const VarNode* op, std::ostream& os) override; // NOLINT(*)
void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*) void VisitExpr_(const LoadNode* op, std::ostream& os) override; // NOLINT(*)
...@@ -149,6 +162,12 @@ class CodeGenC : ...@@ -149,6 +162,12 @@ class CodeGenC :
*/ */
virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*) virtual void PrintType(DataType t, std::ostream& os); // NOLINT(*)
/*! /*!
* Print Type represetnation of type type.
* \param type The type representation.
* \param os The stream to print the ctype into
*/
virtual void PrintType(const Type& type, std::ostream& os); // NOLINT(*)
/*!
* \brief Print expr representing the thread tag * \brief Print expr representing the thread tag
* \param IterVar iv The thread index to be binded; * \param IterVar iv The thread index to be binded;
*/ */
...@@ -223,12 +242,6 @@ class CodeGenC : ...@@ -223,12 +242,6 @@ class CodeGenC :
// override // override
void PrintSSAAssign( void PrintSSAAssign(
const std::string& target, const std::string& src, DataType t) final; const std::string& target, const std::string& src, DataType t) final;
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*! \brief reserves common C keywords */ /*! \brief reserves common C keywords */
void ReserveKeywordsAsUnique(); void ReserveKeywordsAsUnique();
...@@ -237,6 +250,13 @@ class CodeGenC : ...@@ -237,6 +250,13 @@ class CodeGenC :
return volatile_buf_.count(buf_var) != 0; return volatile_buf_.count(buf_var) != 0;
} }
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const VarNode*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_;
private: private:
/*! \brief whether to print in SSA form */ /*! \brief whether to print in SSA form */
bool print_ssa_form_{false}; bool print_ssa_form_{false};
......
...@@ -41,59 +41,16 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) { ...@@ -41,59 +41,16 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts) {
CodeGenC::Init(output_ssa); CodeGenC::Init(output_ssa);
} }
void CodeGenCHost::AddFunction(LoweredFunc f) { void CodeGenCHost::PrintFuncPrefix() { // NOLINT(*)
// clear previous generated state. stream << "#ifdef __cplusplus\n"
this->InitFuncState(f); << "extern \"C\"\n"
// reserve keywords << "#endif\n"
ReserveKeywordsAsUnique(); << "TVM_DLL int32_t";
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.dtype());
}
this->stream << "#ifdef __cplusplus\n";
this->stream << "extern \"C\"\n";
this->stream << "#endif\n";
this->stream << "TVM_DLL int32_t " << f->name << "(";
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.dtype().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
stream << ' ';
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
} else {
stream << "void";
}
stream << "*";
if (f->is_restricted && restrict_keyword_.length() != 0) {
stream << ' ' << restrict_keyword_;
}
} else {
PrintType(v.dtype(), stream);
}
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->PrintIndent();
this->stream << "return 0;\n";
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
} }
std::string CodeGenCHost::Finish() { void CodeGenCHost::PrintFinalReturn() { // NOLINT(*)
return CodeGenC::Finish(); this->PrintIndent();
stream << "return 0;\n";
} }
void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
...@@ -277,20 +234,25 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, ...@@ -277,20 +234,25 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op,
<< "? (" << a_id << ") : (" << b_id << "))"; << "? (" << a_id << ") : (" << b_id << "))";
} }
runtime::Module BuildCHost(Array<LoweredFunc> funcs) { runtime::Module BuildCHost(IRModule mod) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
bool emit_asserts = false; bool emit_asserts = false;
CodeGenCHost cg; CodeGenCHost cg;
cg.Init(output_ssa, emit_asserts); cg.Init(output_ssa, emit_asserts);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodegenCHost: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
cg.AddFunction(f); cg.AddFunction(f);
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
return CSourceModuleCreate(code, "c"); return CSourceModuleCreate(code, "c");
} }
TVM_REGISTER_GLOBAL("codegen.build_c") TVM_REGISTER_GLOBAL("target.build.c")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCHost(args[0]); *rv = BuildCHost(args[0]);
}); });
......
...@@ -36,10 +36,10 @@ class CodeGenCHost final : public CodeGenC { ...@@ -36,10 +36,10 @@ class CodeGenCHost final : public CodeGenC {
public: public:
CodeGenCHost(); CodeGenCHost();
void Init(bool output_ssa, bool emit_asserts); void Init(bool output_ssa, bool emit_asserts);
void AddFunction(LoweredFunc f);
std::string Finish();
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
void PrintFuncPrefix() final; // NOLINT(*)
void PrintFinalReturn() final; // NOLINT(*)
// overload visitor functions // overload visitor functions
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
......
...@@ -43,9 +43,9 @@ void CodeGenCUDA::Init(bool output_ssa) { ...@@ -43,9 +43,9 @@ void CodeGenCUDA::Init(bool output_ssa) {
CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state); CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
} }
void CodeGenCUDA::AddFunction(LoweredFunc f) {
this->stream << "extern \"C\" __global__ "; void CodeGenCUDA::PrintFuncPrefix() {
CodeGenC::AddFunction(f); stream << "extern \"C\" __global__ void";
} }
std::string CodeGenCUDA::Finish() { std::string CodeGenCUDA::Finish() {
...@@ -424,11 +424,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { ...@@ -424,11 +424,11 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
} }
void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::fragment_shape) { if (op->attr_key == tir::attr::fragment_shape) {
const VarNode* buffer = op->node.as<VarNode>(); const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* shape_str = op->value.as<StringImmNode>(); const StringImmNode* shape_str = op->value.as<StringImmNode>();
fragment_shapes[buffer] = shape_str->value; fragment_shapes[buffer] = shape_str->value;
} else if (op->attr_key == attr::fragment_layout) { } else if (op->attr_key == tir::attr::fragment_layout) {
const VarNode* buffer = op->node.as<VarNode>(); const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>(); const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value; fragment_layouts[buffer] = layout_str->value;
......
...@@ -37,12 +37,12 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -37,12 +37,12 @@ class CodeGenCUDA final : public CodeGenC {
public: public:
CodeGenCUDA(); CodeGenCUDA();
void Init(bool output_ssa); void Init(bool output_ssa);
void AddFunction(LoweredFunc f);
std::string Finish(); std::string Finish();
bool need_include_path() { bool need_include_path() {
return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); return (enable_fp16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
} }
// override behavior // override behavior
void PrintFuncPrefix() final;
void VisitStmt_(const ForNode* op) final; void VisitStmt_(const ForNode* op) final;
void PrintStorageSync(const CallNode* op) final; void PrintStorageSync(const CallNode* op) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
......
...@@ -31,10 +31,10 @@ ...@@ -31,10 +31,10 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
void CodeGenMetal::InitFuncState(LoweredFunc f) { void CodeGenMetal::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f); CodeGenC::InitFuncState(f);
// analyze the data; // analyze the data;
for (Var arg : f->args) { for (Var arg : f->params) {
if (arg.dtype().is_handle()) { if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global"; alloc_storage_scope_[arg.get()] = "global";
} }
...@@ -49,48 +49,55 @@ CodeGenMetal::CodeGenMetal() { ...@@ -49,48 +49,55 @@ CodeGenMetal::CodeGenMetal() {
<< "};\n\n"; << "};\n\n";
} }
void CodeGenMetal::AddFunction(LoweredFunc f) { void CodeGenMetal::AddFunction(const PrimFunc& f) {
// clear previous generated state. // clear previous generated state.
this->InitFuncState(f); this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1 // skip the first underscore, so SSA variable starts from _1
GetUniqueName("_"); GetUniqueName("_");
// add to alloc buffer type. // add to alloc buffer type.
for (const auto & kv : f->handle_data_type) { auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
RegisterHandleType(kv.first.get(), kv.second.dtype()); CHECK(global_symbol.defined())
} << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
// Function header. // Function header.
this->stream << "kernel void " << f->name << "(\n"; this->stream << "kernel void " << static_cast<std::string>(global_symbol) << "(";
// Buffer arguments // Buffer arguments
size_t num_buffer = 0; size_t num_buffer = 0;
for (size_t i = 0; i < f->args.size(); ++i, ++num_buffer) { for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
Var v = f->args[i]; Var v = f->params[i];
if (!v.dtype().is_handle()) break; if (!v.dtype().is_handle()) break;
stream << " "; stream << " ";
std::string vid = AllocVarID(v.get()); std::string vid = AllocVarID(v.get());
auto it = alloc_storage_scope_.find(v.get()); auto it = alloc_storage_scope_.find(v.get());
CHECK(it != alloc_storage_scope_.end()); if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream); PrintStorageScope(it->second, stream);
}
stream << ' '; stream << ' ';
if (handle_data_type_.count(v.get())) { PrintType(GetType(v), stream);
PrintType(handle_data_type_.at(v.get()), stream); // Register handle data type
stream << "*"; // TODO(tvm-team): consider simply keep type info in the
} else { // type annotation(via a normalizing rewriting).
PrintType(v.dtype(), stream); if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(v.get(), prim->dtype);
}
} }
stream << ' ' << vid stream << ' ' << vid
<< " [[ buffer(" << i << ") ]],\n"; << " [[ buffer(" << i << ") ]],\n";
} }
// Setup normal arguments. // Setup normal arguments.
size_t nargs = f->args.size() - num_buffer; size_t nargs = f->params.size() - num_buffer;
std::string varg = GetUniqueName("arg"); std::string varg = GetUniqueName("arg");
if (nargs != 0) { if (nargs != 0) {
std::string arg_buf_type = f->name + "_args_t"; std::string arg_buf_type = static_cast<std::string>(global_symbol) + "_args_t";
stream << " constant " << arg_buf_type << "& " << varg stream << " constant " << arg_buf_type << "& " << varg
<< " [[ buffer(" << num_buffer << ") ]],\n"; << " [[ buffer(" << num_buffer << ") ]],\n";
// declare the struct // declare the struct
decl_stream << "struct " << arg_buf_type << " {\n"; decl_stream << "struct " << arg_buf_type << " {\n";
for (size_t i = num_buffer; i < f->args.size(); ++i) { for (size_t i = num_buffer; i < f->params.size(); ++i) {
Var v = f->args[i]; Var v = f->params[i];
CHECK(!v.dtype().is_handle()); CHECK(!v.dtype().is_handle());
std::string vid = AllocVarID(v.get()); std::string vid = AllocVarID(v.get());
std::ostringstream vref; std::ostringstream vref;
...@@ -113,7 +120,10 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { ...@@ -113,7 +120,10 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx");
CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx");
int work_dim = 0; int work_dim = 0;
for (IterVar iv : f->thread_axis) { auto thread_axis = f->GetAttr<Array<tir::IterVar>>(tir::attr::kDeviceThreadAxis);
CHECK(thread_axis.defined());
for (IterVar iv : thread_axis) {
runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag);
work_dim = std::max(work_dim, scope.dim_index + 1); work_dim = std::max(work_dim, scope.dim_index + 1);
} }
...@@ -127,7 +137,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { ...@@ -127,7 +137,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
stream << " threadIdx [[thread_position_in_threadgroup]]\n"; stream << " threadIdx [[thread_position_in_threadgroup]]\n";
} }
// bind thread axis // bind thread axis
for (IterVar iv : f->thread_axis) { for (IterVar iv : thread_axis) {
CHECK(!var_idmap_.count(iv->var.get())); CHECK(!var_idmap_.count(iv->var.get()));
std::string vname = iv->thread_tag; std::string vname = iv->thread_tag;
if (work_dim <= 1) { if (work_dim <= 1) {
...@@ -257,14 +267,23 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT ...@@ -257,14 +267,23 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT
} }
} }
runtime::Module BuildMetal(Array<LoweredFunc> funcs) { runtime::Module BuildMetal(IRModule mod) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenMetal cg; CodeGenMetal cg;
cg.Init(output_ssa); cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenMetal: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
std::string fmt = "metal"; std::string fmt = "metal";
std::string source = ""; std::string source = "";
...@@ -273,10 +292,10 @@ runtime::Module BuildMetal(Array<LoweredFunc> funcs) { ...@@ -273,10 +292,10 @@ runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
code = (*f)(code).operator std::string(); code = (*f)(code).operator std::string();
fmt = "metallib"; fmt = "metallib";
} }
return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source); return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source);
} }
TVM_REGISTER_GLOBAL("codegen.build_metal") TVM_REGISTER_GLOBAL("target.build.metal")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildMetal(args[0]); *rv = BuildMetal(args[0]);
}); });
......
...@@ -34,10 +34,10 @@ namespace codegen { ...@@ -34,10 +34,10 @@ namespace codegen {
class CodeGenMetal final : public CodeGenC { class CodeGenMetal final : public CodeGenC {
public: public:
CodeGenMetal(); CodeGenMetal();
void AddFunction(LoweredFunc f);
// override print thread tag. // override print thread tag.
void PrintArgUnionDecl(); void PrintArgUnionDecl();
void InitFuncState(LoweredFunc f) final; void AddFunction(const PrimFunc& f); // NOLINT(*)
void InitFuncState(const PrimFunc& f) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
...@@ -50,9 +50,10 @@ class CodeGenMetal final : public CodeGenC { ...@@ -50,9 +50,10 @@ class CodeGenMetal final : public CodeGenC {
const std::string& vec, DataType t, int i, const std::string& value) final; const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor // overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
// reuse parent's function.
using CodeGenC::PrintType;
private: private:
int thread_index_bits_{32}; int thread_index_bits_{32};
......
...@@ -35,18 +35,17 @@ CodeGenOpenCL::CodeGenOpenCL() { ...@@ -35,18 +35,17 @@ CodeGenOpenCL::CodeGenOpenCL() {
restrict_keyword_ = "restrict"; restrict_keyword_ = "restrict";
} }
void CodeGenOpenCL::InitFuncState(LoweredFunc f) { void CodeGenOpenCL::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f); CodeGenC::InitFuncState(f);
for (Var arg : f->args) { for (Var arg : f->params) {
if (arg.dtype().is_handle()) { if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global"; alloc_storage_scope_[arg.get()] = "global";
} }
} }
} }
void CodeGenOpenCL::AddFunction(LoweredFunc f) { void CodeGenOpenCL::PrintFuncPrefix() {
this->stream << "__kernel "; stream << "__kernel void";
CodeGenC::AddFunction(f);
} }
std::string CodeGenOpenCL::Finish() { std::string CodeGenOpenCL::Finish() {
...@@ -239,50 +238,31 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO ...@@ -239,50 +238,31 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode *op, std::ostream& os) { // NO
} }
} }
template<typename T> runtime::Module BuildOpenCL(IRModule mod) {
inline void PrintBinaryExpr(const T* op,
const char* opstr,
std::ostream& os,
CodeGenOpenCL* p) {
if (op->dtype.lanes() == 1) {
os << opstr << "((";
p->PrintType(op->a->dtype, os);
os << ")";
p->PrintExpr(op->a, os);
os << ", (";
p->PrintType(op->b->dtype, os);
os << ")";
p->PrintExpr(op->b, os);
os << ')';
} else {
p->PrintVecBinaryOp(opstr, op->dtype, op->a, op->b, os);
}
}
void CodeGenOpenCL::VisitExpr_(const MinNode *op, std::ostream& os) {
PrintBinaryExpr(op, "min", os, this);
}
void CodeGenOpenCL::VisitExpr_(const MaxNode *op, std::ostream& os) {
PrintBinaryExpr(op, "max", os, this);
}
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenOpenCL cg; CodeGenOpenCL cg;
cg.Init(output_ssa); cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) { if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string(); code = (*f)(code).operator std::string();
} }
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs), code); return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code);
} }
TVM_REGISTER_GLOBAL("codegen.build_opencl") TVM_REGISTER_GLOBAL("target.build.opencl")
.set_body_typed(BuildOpenCL); .set_body_typed(BuildOpenCL);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -34,11 +34,11 @@ namespace codegen { ...@@ -34,11 +34,11 @@ namespace codegen {
class CodeGenOpenCL final : public CodeGenC { class CodeGenOpenCL final : public CodeGenC {
public: public:
CodeGenOpenCL(); CodeGenOpenCL();
void AddFunction(LoweredFunc f);
std::string Finish(); std::string Finish();
// override print thread tag. // override print thread tag.
void InitFuncState(LoweredFunc f) final; void InitFuncState(const PrimFunc& f) final;
void PrintFuncPrefix() final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*)
...@@ -56,9 +56,6 @@ class CodeGenOpenCL final : public CodeGenC { ...@@ -56,9 +56,6 @@ class CodeGenOpenCL final : public CodeGenC {
// overload visitor // overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode *op, std::ostream& os) final; // NOLINT(*)
// overload min and max to avoid ambiguous call errors
void VisitExpr_(const MinNode *op, std::ostream& os) final;
void VisitExpr_(const MaxNode *op, std::ostream& os) final;
private: private:
// whether enable fp16 and fp64 extension // whether enable fp16 and fp64 extension
......
...@@ -37,7 +37,7 @@ namespace codegen { ...@@ -37,7 +37,7 @@ namespace codegen {
CodeGenOpenGL::CodeGenOpenGL() CodeGenOpenGL::CodeGenOpenGL()
: output_(nullptr), output_iter_var_(nullptr) {} : output_(nullptr), output_iter_var_(nullptr) {}
void CodeGenOpenGL::InitFuncState(LoweredFunc f) { void CodeGenOpenGL::InitFuncState(const PrimFunc& f) {
CodeGenC::InitFuncState(f); CodeGenC::InitFuncState(f);
output_ = nullptr; output_ = nullptr;
inputs_.clear(); inputs_.clear();
...@@ -47,7 +47,7 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) { ...@@ -47,7 +47,7 @@ void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
this->stream.str(""); this->stream.str("");
} }
void CodeGenOpenGL::AddFunction(LoweredFunc f) { void CodeGenOpenGL::AddFunction(const PrimFunc& f) {
// clear previous generated state. // clear previous generated state.
this->InitFuncState(f); this->InitFuncState(f);
...@@ -56,15 +56,17 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { ...@@ -56,15 +56,17 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
// skip the first underscore, so SSA variable starts from _1 // skip the first underscore, so SSA variable starts from _1
GetUniqueName("_"); GetUniqueName("_");
// add to alloc buffer type.
for (const auto& kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.dtype());
}
// Allocate argument names. Store in `var_idmap_`. // Allocate argument names. Store in `var_idmap_`.
for (auto arg : f->args) { for (auto arg : f->params) {
auto arg_name = GetUniqueName(arg.get()->name_hint); auto arg_name = GetUniqueName(arg.get()->name_hint);
var_idmap_[arg.get()] = arg_name; var_idmap_[arg.get()] = arg_name;
if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
RegisterHandleType(arg.get(), prim->dtype);
}
}
} }
thread_extent_var_ = GetUniqueName("thread_extent"); thread_extent_var_ = GetUniqueName("thread_extent");
...@@ -80,7 +82,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { ...@@ -80,7 +82,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
this->stream << "}\n\n"; this->stream << "}\n\n";
// Declare arguments. // Declare arguments.
for (auto arg : f->args) { for (auto arg : f->params) {
if (this->inputs_.find(arg.get()) != this->inputs_.cend()) { if (this->inputs_.find(arg.get()) != this->inputs_.cend()) {
// Declare input texture. // Declare input texture.
// Format: // Format:
...@@ -138,7 +140,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { ...@@ -138,7 +140,7 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
std::vector<std::string> arg_names; std::vector<std::string> arg_names;
std::vector<runtime::OpenGLArgKind> arg_kinds; std::vector<runtime::OpenGLArgKind> arg_kinds;
for (auto arg : f->args) { for (auto arg : f->params) {
std::string name = GetVarID(arg.get()); std::string name = GetVarID(arg.get());
runtime::OpenGLArgKind kind; runtime::OpenGLArgKind kind;
...@@ -154,7 +156,11 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) { ...@@ -154,7 +156,11 @@ void CodeGenOpenGL::AddFunction(LoweredFunc f) {
arg_kinds.push_back(kind); arg_kinds.push_back(kind);
} }
shaders_[f->name] = runtime::OpenGLShader( auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute";
shaders_[static_cast<std::string>(global_symbol)] = runtime::OpenGLShader(
this->decl_stream.str() + this->stream.str(), this->decl_stream.str() + this->stream.str(),
std::move(arg_names), std::move(arg_kinds), std::move(arg_names), std::move(arg_kinds),
this->thread_extent_var_); this->thread_extent_var_);
...@@ -283,18 +289,27 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) { ...@@ -283,18 +289,27 @@ void CodeGenOpenGL::VisitStmt_(const EvaluateNode* op) {
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n"; this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
} }
runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) { runtime::Module BuildOpenGL(IRModule mod) {
bool output_ssa = false; bool output_ssa = false;
CodeGenOpenGL cg; CodeGenOpenGL cg;
cg.Init(output_ssa); cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenOpenGL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
auto shaders = cg.Finish(); auto shaders = cg.Finish();
return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs)); return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(mod));
} }
TVM_REGISTER_GLOBAL("codegen.build_opengl") TVM_REGISTER_GLOBAL("target.build.opengl")
.set_body_typed(BuildOpenGL); .set_body_typed(BuildOpenGL);
} // namespace codegen } // namespace codegen
......
...@@ -37,10 +37,10 @@ namespace codegen { ...@@ -37,10 +37,10 @@ namespace codegen {
class CodeGenOpenGL final : public CodeGenC { class CodeGenOpenGL final : public CodeGenC {
public: public:
CodeGenOpenGL(); CodeGenOpenGL();
void AddFunction(LoweredFunc f);
std::unordered_map<std::string, runtime::OpenGLShader> Finish(); std::unordered_map<std::string, runtime::OpenGLShader> Finish();
void InitFuncState(LoweredFunc f) final; void AddFunction(const PrimFunc& f);
void InitFuncState(const PrimFunc& f) final;
void BindThreadIndex(const IterVar& iv) final; void BindThreadIndex(const IterVar& iv) final;
void VisitStmt_(const StoreNode* op) final; void VisitStmt_(const StoreNode* op) final;
std::string TexelFetch(const VarNode* buffer, PrimExpr index); std::string TexelFetch(const VarNode* buffer, PrimExpr index);
......
...@@ -68,14 +68,13 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) { ...@@ -68,14 +68,13 @@ void CodeGenVivadoHLS::PrintType(DataType t, std::ostream& os) {
} }
} }
void CodeGenVivadoHLS::AddFunction(LoweredFunc f) { void CodeGenVivadoHLS::PrintFuncPrefix() {
this->stream << "extern \"C\" "; stream << "extern \"C\" void";
CodeGenC::AddFunction(f);
} }
void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) { void CodeGenVivadoHLS::PreFunctionBody(const PrimFunc& f) {
for (size_t i = 0; i < f->args.size(); ++i) { for (size_t i = 0; i < f->params.size(); ++i) {
Var v = f->args[i]; Var v = f->params[i];
std::string vid = GetVarID(v.get()); std::string vid = GetVarID(v.get());
if (v.dtype().is_handle()) { if (v.dtype().is_handle()) {
this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n"; this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n";
...@@ -126,21 +125,34 @@ void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOL ...@@ -126,21 +125,34 @@ void CodeGenVivadoHLS::VisitExpr_(const MaxNode *op, std::ostream& os) { // NOL
} }
runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) { runtime::Module BuildSDAccel(IRModule mod, std::string target_str) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
CodeGenVivadoHLS cg; CodeGenVivadoHLS cg;
// Generate source code for get_source(). // Generate source code for get_source().
cg.Init(output_ssa); cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenVHLS: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f); cg.AddFunction(f);
} }
std::string whole_code = cg.Finish(); std::string whole_code = cg.Finish();
// Generate source code for compilation. // Generate source code for compilation.
Array<Array<PrimExpr> > kernel_info; Array<Array<PrimExpr> > kernel_info;
for (LoweredFunc f : funcs) {
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenOpenCL: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
CodeGenVivadoHLS cg; CodeGenVivadoHLS cg;
cg.Init(output_ssa); cg.Init(output_ssa);
cg.AddFunction(f); cg.AddFunction(f);
...@@ -148,7 +160,12 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) { ...@@ -148,7 +160,12 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) { if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code).operator std::string(); code = (*f)(code).operator std::string();
} }
kernel_info.push_back(Array<PrimExpr>({f->name, code}));
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
std::string func_name = global_symbol;
kernel_info.push_back(Array<PrimExpr>({func_name, code}));
} }
std::string xclbin; std::string xclbin;
...@@ -158,10 +175,10 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) { ...@@ -158,10 +175,10 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
} else { } else {
LOG(FATAL) << "Cannot compile Vivado HLS code."; LOG(FATAL) << "Cannot compile Vivado HLS code.";
} }
return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), whole_code); return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(mod), whole_code);
} }
TVM_REGISTER_GLOBAL("codegen.build_sdaccel") TVM_REGISTER_GLOBAL("target.build.sdaccel")
.set_body_typed(BuildSDAccel); .set_body_typed(BuildSDAccel);
} // namespace codegen } // namespace codegen
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
* KIND, either express or implied. See the License for the * KIND, either express or implied. See the License for the
* specific language governing permissions and limitations * specific language governing permissions and limitations
* under the License. * under the License.
*/ 5B5B */
/*! /*!
* \file codegen_vhls.h * \file codegen_vhls.h
...@@ -37,10 +37,11 @@ class CodeGenVivadoHLS final : public CodeGenC { ...@@ -37,10 +37,11 @@ class CodeGenVivadoHLS final : public CodeGenC {
public: public:
void Init(bool output_ssa); void Init(bool output_ssa);
void PrintType(DataType t, std::ostream& os); void PrintType(DataType t, std::ostream& os);
void AddFunction(LoweredFunc f);
void PreFunctionBody(LoweredFunc f); void PrintFuncPrefix() final;
void VisitExpr_(const MinNode *op, std::ostream& os); void PreFunctionBody(const PrimFunc& f) final;
void VisitExpr_(const MaxNode *op, std::ostream& os); void VisitExpr_(const MinNode *op, std::ostream& os) final;
void VisitExpr_(const MaxNode *op, std::ostream& os) final;
}; };
} // namespace codegen } // namespace codegen
......
...@@ -70,7 +70,7 @@ class SPIRVTools { ...@@ -70,7 +70,7 @@ class SPIRVTools {
spv_context ctx_; spv_context ctx_;
}; };
runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) { runtime::Module BuildSPIRV(IRModule mod) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
using tvm::runtime::VulkanShader; using tvm::runtime::VulkanShader;
...@@ -81,8 +81,21 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) { ...@@ -81,8 +81,21 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc"); const auto* postproc = Registry::Get("tvm_callback_vulkan_postproc");
CodeGenSPIRV cg; CodeGenSPIRV cg;
for (LoweredFunc f : funcs) {
f = PointerValueTypeRewrite(f); for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenSPIRV: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
f = PointerValueTypeRewrite(std::move(f));
VulkanShader shader; VulkanShader shader;
shader.data = cg.BuildFunction(f); shader.data = cg.BuildFunction(f);
...@@ -97,13 +110,14 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) { ...@@ -97,13 +110,14 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
reinterpret_cast<char*>(dmlc::BeginPtr(shader.data))); reinterpret_cast<char*>(dmlc::BeginPtr(shader.data)));
} }
code_data << spirv_tools.BinaryToText(shader.data); code_data << spirv_tools.BinaryToText(shader.data);
smap[f->name] = std::move(shader); smap[f_name] = std::move(shader);
} }
return runtime::VulkanModuleCreate( return runtime::VulkanModuleCreate(
smap, ExtractFuncInfo(funcs), code_data.str()); smap, ExtractFuncInfo(mod), code_data.str());
} }
TVM_REGISTER_GLOBAL("codegen.build_vulkan") TVM_REGISTER_GLOBAL("target.build.vulkan")
.set_body_typed(BuildSPIRV); .set_body_typed(BuildSPIRV);
} // namespace codegen } // namespace codegen
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/runtime/container.h>
#include <string> #include <string>
#include "codegen_spirv.h" #include "codegen_spirv.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
...@@ -30,18 +31,20 @@ ...@@ -30,18 +31,20 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) { std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f) {
this->InitFuncState(); this->InitFuncState();
CHECK(f->is_restricted) CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias))
<< "SPIRV only takes restricted memory model"; << "SPIRV only takes restricted memory model";
std::vector<Var> pod_args; std::vector<Var> pod_args;
uint32_t num_buffer = 0; uint32_t num_buffer = 0;
for (Var arg : f->args) {
for (Var arg : f->params) {
DataType t = arg.dtype(); DataType t = arg.dtype();
if (t.is_handle()) { if (t.is_handle()) {
auto it = f->handle_data_type.find(arg); if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
if (it != f->handle_data_type.end()) { auto* prim = ptr->element_type.as<PrimTypeNode>();
DataType value_type = (*it).second.dtype(); CHECK(prim);
DataType value_type = prim->dtype;
spirv::Value arg_value = builder_->BufferArgument( spirv::Value arg_value = builder_->BufferArgument(
builder_->GetSType(value_type), 0, num_buffer); builder_->GetSType(value_type), 0, num_buffer);
storage_info_[arg.get()].UpdateContentType(value_type); storage_info_[arg.get()].UpdateContentType(value_type);
...@@ -75,7 +78,11 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) { ...@@ -75,7 +78,11 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const LoweredFunc& f) {
builder_->MakeInst(spv::OpReturn); builder_->MakeInst(spv::OpReturn);
builder_->MakeInst(spv::OpFunctionEnd); builder_->MakeInst(spv::OpFunctionEnd);
builder_->CommitKernelFunction(func_ptr, f->name); auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
builder_->CommitKernelFunction(func_ptr, static_cast<std::string>(global_symbol));
return builder_->Finalize(); return builder_->Finalize();
} }
...@@ -607,7 +614,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ...@@ -607,7 +614,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
} }
void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) { if (iv->thread_tag.length() != 0) {
if (!var_map_.count(iv->var.get())) { if (!var_map_.count(iv->var.get())) {
......
...@@ -53,7 +53,7 @@ class CodeGenSPIRV: ...@@ -53,7 +53,7 @@ class CodeGenSPIRV:
* \param f The function to be added. * \param f The function to be added.
* \return The final spirv module. * \return The final spirv module.
*/ */
virtual std::vector<uint32_t> BuildFunction(const LoweredFunc& f); virtual std::vector<uint32_t> BuildFunction(const PrimFunc& f);
/*! /*!
* \brief Create Value for expression e * \brief Create Value for expression e
* \param e The expression to be created value for. * \param e The expression to be created value for.
......
...@@ -21,7 +21,10 @@ ...@@ -21,7 +21,10 @@
* \file codegen_stackvm.cc * \file codegen_stackvm.cc
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/ir/module.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/tir/function.h>
#include <limits> #include <limits>
#include <utility> #include <utility>
#include "codegen_stackvm.h" #include "codegen_stackvm.h"
...@@ -54,9 +57,9 @@ StackVM::StructFieldKind MapFieldKind(int64_t kind) { ...@@ -54,9 +57,9 @@ StackVM::StructFieldKind MapFieldKind(int64_t kind) {
return StackVM::kArrData; return StackVM::kArrData;
} }
StackVM CodeGenStackVM::Compile(LoweredFunc f) { StackVM CodeGenStackVM::Compile(const PrimFunc& f) {
for (size_t i = 0; i < f->args.size(); ++i) { for (size_t i = 0; i < f->params.size(); ++i) {
Var v = f->args[i]; Var v = f->params[i];
int vid = AllocVarID(v.get()); int vid = AllocVarID(v.get());
CHECK_EQ(static_cast<size_t>(vid), i); CHECK_EQ(static_cast<size_t>(vid), i);
} }
...@@ -525,19 +528,32 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) { ...@@ -525,19 +528,32 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) {
this->Push(op->body); this->Push(op->body);
} }
runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) { runtime::Module BuildStackVM(const IRModule& mod) {
CHECK_NE(funcs.size(), 0U);
std::unordered_map<std::string, StackVM> fmap; std::unordered_map<std::string, StackVM> fmap;
for (LoweredFunc f : funcs) { std::string entry_func;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenStackVM: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
std::string f_name = global_symbol;
StackVM vm = codegen::CodeGenStackVM().Compile(f); StackVM vm = codegen::CodeGenStackVM().Compile(f);
CHECK(!fmap.count(f->name)) CHECK(!fmap.count(f_name))
<< "Function name " << f->name << "already exist in list"; << "Function name " << f_name << "already exist in list";
fmap[f->name] = std::move(vm); fmap[f_name] = std::move(vm);
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
entry_func = f_name;
}
} }
return runtime::StackVMModuleCreate(fmap, funcs[0]->name);
return runtime::StackVMModuleCreate(fmap, entry_func);
} }
TVM_REGISTER_GLOBAL("codegen.build_stackvm") TVM_REGISTER_GLOBAL("target.build.stackvm")
.set_body_typed(BuildStackVM); .set_body_typed(BuildStackVM);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -56,7 +56,7 @@ class CodeGenStackVM ...@@ -56,7 +56,7 @@ class CodeGenStackVM
* \note Only call compile once, * \note Only call compile once,
* create a new codegen object each time. * create a new codegen object each time.
*/ */
StackVM Compile(LoweredFunc f); StackVM Compile(const PrimFunc& f);
/*! \brief Push stmt to generate new code */ /*! \brief Push stmt to generate new code */
void Push(const Stmt& n); void Push(const Stmt& n);
/*! \brief Push expr to generate new code */ /*! \brief Push expr to generate new code */
......
...@@ -91,7 +91,7 @@ Stmt MakeCrossThreadReduction( ...@@ -91,7 +91,7 @@ Stmt MakeCrossThreadReduction(
freduce_args, CallNode::Intrinsic)); freduce_args, CallNode::Intrinsic));
reduce_body = AttrStmtNode::make( reduce_body = AttrStmtNode::make(
reduces[0]->combiner, reduces[0]->combiner,
attr::reduce_scope, tir::attr::reduce_scope,
make_zero(DataType::Handle()), make_zero(DataType::Handle()),
reduce_body); reduce_body);
std::vector<Stmt> assigns(size); std::vector<Stmt> assigns(size);
...@@ -109,7 +109,7 @@ Stmt MakeCrossThreadReduction( ...@@ -109,7 +109,7 @@ Stmt MakeCrossThreadReduction(
body = AllocateNode::make( body = AllocateNode::make(
res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body); res_handles[idx - 1], reduces[idx - 1]->dtype, {1}, const_true(), body);
body = AttrStmtNode::make( body = AttrStmtNode::make(
res_handles[idx - 1], attr::storage_scope, StringImmNode::make("local"), body); res_handles[idx - 1], tir::attr::storage_scope, StringImmNode::make("local"), body);
} }
body = Substitute(body, value_map); body = Substitute(body, value_map);
return MergeNest(nest, body); return MergeNest(nest, body);
......
...@@ -165,7 +165,8 @@ Stmt ExternOpNode::BuildProvide( ...@@ -165,7 +165,8 @@ Stmt ExternOpNode::BuildProvide(
const std::unordered_map<IterVar, Range>& dom_map, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const { bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); Stmt ret = AttrStmtNode::make(
make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
Array<ObjectRef> bind_spec; Array<ObjectRef> bind_spec;
Array<PrimExpr> tuple; Array<PrimExpr> tuple;
...@@ -176,7 +177,7 @@ Stmt ExternOpNode::BuildProvide( ...@@ -176,7 +177,7 @@ Stmt ExternOpNode::BuildProvide(
tuple.push_back(buffer->shape[k]); tuple.push_back(buffer->shape[k]);
} }
ret = AttrStmtNode::make( ret = AttrStmtNode::make(
bind_spec, attr::buffer_bind_scope, bind_spec, tir::attr::buffer_bind_scope,
CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret); CallNode::make(DataType::Handle(), intrinsic::tvm_tuple, tuple, CallNode::Intrinsic), ret);
}; };
for (size_t i = output_placeholders.size(); i != 0; --i) { for (size_t i = output_placeholders.size(); i != 0; --i) {
......
...@@ -186,7 +186,8 @@ Stmt HybridOpNode::BuildProvide( ...@@ -186,7 +186,8 @@ Stmt HybridOpNode::BuildProvide(
const std::unordered_map<IterVar, Range> &dom_map, const std::unordered_map<IterVar, Range> &dom_map,
bool debug_keep_trivial_loop) const { bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt ret = AttrStmtNode::make(make_zero(DataType::Int(32)), attr::extern_scope, 0, this->body); Stmt ret = AttrStmtNode::make(
make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap; std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) { for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i); rmap[outputs[i]] = stage->op.output(i);
......
...@@ -168,7 +168,7 @@ MakeLoopNest(const Stage& stage, ...@@ -168,7 +168,7 @@ MakeLoopNest(const Stage& stage,
// annotate the extent of the IterVar // annotate the extent of the IterVar
if (!new_loop_var) { if (!new_loop_var) {
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
AttrStmtNode::make(iv, attr::loop_scope, iv->var, no_op)); AttrStmtNode::make(iv, tir::attr::loop_scope, iv->var, no_op));
} }
} }
// message passing to get offset of root iter vars. // message passing to get offset of root iter vars.
......
...@@ -287,10 +287,10 @@ Stmt ScanOpNode::BuildProvide( ...@@ -287,10 +287,10 @@ Stmt ScanOpNode::BuildProvide(
bool debug_keep_trivial_loop) const { bool debug_keep_trivial_loop) const {
CHECK_EQ(stage->op.operator->(), this); CHECK_EQ(stage->op.operator->(), this);
Stmt provide = AttrStmtNode::make( Stmt provide = AttrStmtNode::make(
stage->op, attr::scan_update_scope, this->scan_axis->var, stage->op, tir::attr::scan_update_scope, this->scan_axis->var,
EvaluateNode::make(0)); EvaluateNode::make(0));
Stmt init = AttrStmtNode::make( Stmt init = AttrStmtNode::make(
stage->op, attr::scan_init_scope, 0, stage->op, tir::attr::scan_init_scope, 0,
EvaluateNode::make(0)); EvaluateNode::make(0));
size_t begin_scan = 0; size_t begin_scan = 0;
for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
......
...@@ -85,7 +85,7 @@ class InjectAttach : public StmtMutator { ...@@ -85,7 +85,7 @@ class InjectAttach : public StmtMutator {
auto stmt = StmtMutator::VisitStmt(input_stmt); auto stmt = StmtMutator::VisitStmt(input_stmt);
const AttrStmtNode* op = stmt.as<AttrStmtNode>(); const AttrStmtNode* op = stmt.as<AttrStmtNode>();
if (op != nullptr && if (op != nullptr &&
op->attr_key == attr::loop_scope) { op->attr_key == tir::attr::loop_scope) {
if (attach_spec_->attach_type == kScope && if (attach_spec_->attach_type == kScope &&
op->node == attach_spec_->attach_ivar) { op->node == attach_spec_->attach_ivar) {
CHECK(!found_attach) CHECK(!found_attach)
...@@ -131,8 +131,8 @@ class InjectScanStep : public StmtMutator { ...@@ -131,8 +131,8 @@ class InjectScanStep : public StmtMutator {
// update // update
const AttrStmtNode* op = stmt.as<AttrStmtNode>(); const AttrStmtNode* op = stmt.as<AttrStmtNode>();
if (op != nullptr && if (op != nullptr &&
((op->attr_key == attr::scan_update_scope && !is_init_) || ((op->attr_key == tir::attr::scan_update_scope && !is_init_) ||
(op->attr_key == attr::scan_init_scope && is_init_))) { (op->attr_key == tir::attr::scan_init_scope && is_init_))) {
if (op->node.same_as(scan_op_)) { if (op->node.same_as(scan_op_)) {
found_attach = true; found_attach = true;
stmt = AttrStmtNode::make( stmt = AttrStmtNode::make(
...@@ -187,15 +187,15 @@ class SchedulePostProc : public StmtExprMutator { ...@@ -187,15 +187,15 @@ class SchedulePostProc : public StmtExprMutator {
} }
Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::loop_scope || if (op->attr_key == tir::attr::loop_scope ||
op->attr_key == attr::scan_init_scope) { op->attr_key == tir::attr::scan_init_scope) {
return this->VisitStmt(op->body); return this->VisitStmt(op->body);
} else if (op->attr_key == attr::scan_update_scope) { } else if (op->attr_key == tir::attr::scan_update_scope) {
const ScanOpNode* scan = op->node.as<ScanOpNode>(); const ScanOpNode* scan = op->node.as<ScanOpNode>();
CHECK(scan); CHECK(scan);
var_value_[scan->scan_axis->var.get()] = op->value; var_value_[scan->scan_axis->var.get()] = op->value;
return this->VisitStmt(op->body); return this->VisitStmt(op->body);
} else if (op->attr_key == attr::thread_extent) { } else if (op->attr_key == tir::attr::thread_extent) {
// delete duplicated thread extent attr // delete duplicated thread extent attr
auto it = thread_extent_scope_.find(op->node.get()); auto it = thread_extent_scope_.find(op->node.get());
if (it != thread_extent_scope_.end()) { if (it != thread_extent_scope_.end()) {
......
...@@ -32,25 +32,51 @@ ...@@ -32,25 +32,51 @@
namespace tvm { namespace tvm {
namespace tir { namespace tir {
Var::Var(std::string name_hint, DataType t) Var::Var(std::string name_hint, DataType dtype) {
: Var(make_object<VarNode>(t, name_hint)) {} auto n = make_object<VarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}
VarNode::VarNode(DataType t, std::string name_hint) { Var::Var(std::string name_hint, Type type_annotation) {
this->dtype = t; auto n = make_object<VarNode>();
this->name_hint = std::move(name_hint); n->name_hint = std::move(name_hint);
n->dtype = GetRuntimeDataType(type_annotation);
n->type_annotation = std::move(type_annotation);
data_ = std::move(n);
} }
SizeVar::SizeVar(std::string name_hint, DataType t)
: SizeVar(make_object<SizeVarNode>(t, name_hint)) {}
SizeVarNode::SizeVarNode(DataType t, std::string name_hint) Var Var::copy_with_suffix(const std::string& suffix) const {
: VarNode(t, std::move(name_hint)) {} const VarNode* node = get();
ObjectPtr<VarNode> new_ptr;
if (auto* ptr = this->as<SizeVarNode>()) {
new_ptr = make_object<SizeVarNode>(*ptr);
} else {
new_ptr = make_object<VarNode>(*node);
}
new_ptr->name_hint += suffix;
return Var(new_ptr);
}
SizeVar::SizeVar(std::string name_hint, DataType dtype) {
auto n = make_object<SizeVarNode>();
n->name_hint = std::move(name_hint);
n->dtype = std::move(dtype);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("tir.Var") TVM_REGISTER_GLOBAL("tir.Var")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string name_hint, runtime::TVMArgValue type) {
return Var(s, t); if (type.IsObjectRef<Type>()) {
}); return Var(name_hint, type.operator Type());
} else {
return Var(name_hint, type.operator DataType());
}
});
TVM_REGISTER_GLOBAL("tir.SizeVar") TVM_REGISTER_GLOBAL("tir.SizeVar")
.set_body_typed([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
......
...@@ -31,5 +31,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -31,5 +31,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}); });
TVM_REGISTER_NODE_TYPE(LoweredFuncNode); TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -33,14 +33,34 @@ namespace tvm { ...@@ -33,14 +33,34 @@ namespace tvm {
using namespace tir; using namespace tir;
runtime::DataType GetRuntimeDataType(const Type& type) {
if (auto * n = type.as<PrimTypeNode>()) {
return n->dtype;
} else if (type.as<PointerTypeNode>()) {
return DataType::Handle();
} else {
LOG(FATAL) << "Type " << type
<< " does not have a corresponding runtime::DataType";
return DataType::Handle();
}
}
Type GetType(const PrimExpr& expr) { Type GetType(const PrimExpr& expr) {
// TODO(tqchen): add recursive type inference for Call here
// once we introduced the corresponding fields to the IR.
if (auto* ptr = expr.as<tir::VarNode>()) {
// If Var has a more refined type annotation,
// return the type anotation
if (ptr->type_annotation.defined()) {
return ptr->type_annotation;
}
}
// Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype(); runtime::DataType dtype = expr.dtype();
// These types already implies the specific type. // These types already implies the specific type.
if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) { if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) {
return PrimType(dtype); return PrimType(dtype);
} }
// TODO(tqchen): add recursive type inference for Var and Call here
// once we introduced the corresponding fields to the IR.
return PrimType(dtype); return PrimType(dtype);
} }
......
...@@ -68,6 +68,32 @@ class IRSubstitue : public StmtExprMutator { ...@@ -68,6 +68,32 @@ class IRSubstitue : public StmtExprMutator {
} }
} }
PrimExpr VisitExpr_(const LoadNode* op) final {
// NOTE: we do not explicit recursivly mutate op->buffer_var
PrimExpr ret = StmtExprMutator::VisitExpr_(op);
op = ret.as<LoadNode>();
auto it = smap_.find(op->buffer_var.get());
if (it != smap_.end()) {
return LoadNode::make(
op->dtype, Downcast<Var>(it->second), op->index, op->predicate);
} else {
return ret;
}
}
Stmt VisitStmt_(const StoreNode* op) final {
// NOTE: we do not explicit recursivly mutate op->buffer_var
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<StoreNode>();
auto it = smap_.find(op->buffer_var.get());
if (it != smap_.end()) {
return StoreNode::make(
Downcast<Var>(it->second), op->value, op->index, op->predicate);
} else {
return ret;
}
}
private: private:
const std::unordered_map<const VarNode*, PrimExpr>& smap_; const std::unordered_map<const VarNode*, PrimExpr>& smap_;
}; };
......
...@@ -1016,6 +1016,47 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) { ...@@ -1016,6 +1016,47 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
return LoweredFunc(n); return LoweredFunc(n);
} }
PrimFunc PointerValueTypeRewrite(PrimFunc f) {
auto* n = f.CopyOnWrite();
VectorAllocRewriter rewriter;
n->body = rewriter(n->body);
Array<tir::Var> args;
Map<tir::Var, PrimExpr> remap_vars;
for (Var var : f->params) {
if (var.dtype().is_handle()) {
const auto& tvec = rewriter.acc_map_[var.get()];
if (tvec.size() == 1) {
tir::Var new_var(var->name_hint,
PointerType(PrimType(tvec[0])));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !var->type_annotation.defined()) {
tir::Var new_var(var->name_hint,
PointerType(PrimType(tvec[0].with_lanes(1))));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
args.push_back(var);
}
}
} else {
args.push_back(var);
}
}
CHECK_EQ(args.size(), n->params.size());
n->params = args;
n->body = Substitute(n->body, remap_vars);
return f;
}
Stmt StorageRewrite(Stmt stmt) { Stmt StorageRewrite(Stmt stmt) {
stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true); stmt = StoragePlanRewriter().Rewrite(std::move(stmt), true);
return VectorAllocRewriter()(std::move(stmt)); return VectorAllocRewriter()(std::move(stmt));
......
...@@ -83,8 +83,8 @@ def test_meta_data(): ...@@ -83,8 +83,8 @@ def test_meta_data():
text_no_meta = str(f) text_no_meta = str(f)
assert "channels=2" in text assert "channels=2" in text
assert "channels=2" in text_no_meta assert "channels=2" in text_no_meta
assert "meta[SizeVar][0]" in text assert "meta[tir.SizeVar][0]" in text
assert "meta[SizeVar][0]" in text_no_meta assert "meta[tir.SizeVar][0]" in text_no_meta
assert "type_key" in text assert "type_key" in text
assert "type_key" not in text_no_meta assert "type_key" not in text_no_meta
......
...@@ -108,8 +108,32 @@ def test_global_var(): ...@@ -108,8 +108,32 @@ def test_global_var():
assert isinstance(tvar, tvm.ir.GlobalVar) assert isinstance(tvar, tvm.ir.GlobalVar)
def test_tir_var():
nodes = [
{"type_key": ""},
{"type_key": "Variable",
"attrs": {"dtype": "int32", "name": "x"}},
{"type_key": "SizeVar",
"attrs": {"dtype": "int32", "name": "y"}},
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
x = tvm.ir.load_json(json.dumps(data))
assert isinstance(x, tvm.tir.Var)
assert x.name == "x"
data["root"] = 2
y = tvm.ir.load_json(json.dumps(data))
assert isinstance(y, tvm.tir.SizeVar)
assert y.name == "y"
if __name__ == "__main__": if __name__ == "__main__":
test_type_var() test_type_var()
test_incomplete_type() test_incomplete_type()
test_func_tuple_type() test_func_tuple_type()
test_global_var() test_global_var()
test_tir_var()
...@@ -265,7 +265,18 @@ def test_prim_func(): ...@@ -265,7 +265,18 @@ def test_prim_func():
assert func.attrs is None assert func.attrs is None
def test_vars():
x = tvm.tir.Var("xyz", "int8")
assert x.dtype == "int8"
ptype = tvm.ir.PointerType(tvm.ir.PrimType("float"))
x = tvm.tir.Var("xyz", ptype)
assert x.dtype == "handle"
assert x.type_annotation == ptype
assert isinstance(ptype.element_type, tvm.ir.PrimType)
if __name__ == "__main__": if __name__ == "__main__":
test_vars()
test_prim_func() test_prim_func()
test_cast() test_cast()
test_attr() test_attr()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment