Commit 4f1473f3 by Tianqi Chen Committed by GitHub

[CODEGEN] Add LoweredFunc, MakeAPI to build a C API function (#23)

* [CODEGEN] Add LoweredFunc, MakeAPI and SplitHostDevice

* update halideir
parent 3c1020df
Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf
Subproject commit 30bf0f043e6388418958fd1f29259ee43c42b600
......@@ -50,6 +50,9 @@ class Buffer : public NodeRef {
* \return the pointer to the internal node container
*/
inline const BufferNode* operator->() const;
/*! \brief specify container node */
using ContainerType = BufferNode;
};
/*! \brief Node to represent a buffer */
......
......@@ -30,6 +30,7 @@
#endif
#include <stdint.h>
#include <stddef.h>
TVM_EXTERN_C {
......@@ -216,18 +217,45 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Launch a generated TVM function
* \brief TVM Function API: Get resource requirement
*
* By default TVM function try not to do internal allocations.
* Instead, TVMFuncRequirement can be called, given the input arguments.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param out_workspace_size The workspace size needed to launch this function.
* \param out_workspace_align The alignment requirement of workspace.
*
* \note The data pointer in the arrays is not used by requirement.
*/
TVM_DLL int TVMFuncRequirement(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
size_t* out_workspace_size,
size_t* out_workspace_align);
/*!
* \brief TVM Function API: Launch generated function.
*
* \param func function handle to be launched.
* \param args The arguments
* \param arg_type_ids The type id of the arguments
* \param num_args Number of arguments.
* \param stream The stream this function to be launched on.
* \param workspace Additional workspace used to launch this function.
*
* \sa TVMFuncRequirement
*/
TVM_DLL int TVMLaunch(TVMFunctionHandle func,
TVM_DLL int TVMFuncLaunch(TVMFunctionHandle func,
TVMArg* args,
int* arg_type_ids,
int num_args,
TVMStreamHandle stream);
TVMStreamHandle stream,
TVMArrayHandle workspace);
} // TVM_EXTERN_C
#endif // TVM_C_RUNTIME_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file codegen.h
* \brief Collection of Lowlevel IR pass to codegen.
*/
#ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_
#include <string>
#include "./base.h"
#include "./expr.h"
#include "./module.h"
namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);
/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);
/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_H_
......@@ -49,6 +49,48 @@ struct Reduce : public ExprNode<Reduce> {
static constexpr const char* Min = "Min";
};
/*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic {
// Most of the intrinsics is to enab
/*!
* \brief See pesudo code
*
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) {
* assert(arg_type_id[i] == typeid(Type));
* return args[i];
* }
*/
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg";
/*!
* \brief See pesudo code
*
* Type tvm_array_get_field(TVMArray* arr, int field_id) {
* return arr->field;
* }
* \sa TVMArrayFieldKind
*/
constexpr const char* tvm_array_get_field = "tvm_array_get_field";
/*!
* \brief See pesudo code
*
* bool tvm_handle_is_null(void* handle) {
* return handle == nullptr
* }
*/
constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
kData = 0,
kNDim = 1,
kShape = 2,
kStrides = 3,
kTypeCode = 4,
kTypeBits = 5,
kTypeLanes = 6
};
} // namespace intrinsic
// Reuse IR node defintiion from HalideIR
using Halide::Internal::IntImm;
using Halide::Internal::UIntImm;
......
......@@ -9,6 +9,7 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "./expr.h"
#include "./ir.h"
namespace tvm {
namespace ir {
......@@ -51,6 +52,20 @@ class IRMutator {
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
static FMutateStmt& vtable_stmt(); // NOLINT(*)
// Set of overloadable functions
// The underscore allows Mutate not to be shadowed by inheritance
virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
virtual Stmt Mutate_(const For* op, const Stmt& s);
virtual Stmt Mutate_(const Provide* op, const Stmt& s);
virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
virtual Expr Mutate_(const Let* op, const Expr& e);
};
/*!
......
......@@ -57,6 +57,12 @@ Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
bool VerifySSA(const Stmt& ir);
/*!
* \brief Whether the expression have side effect.
* \return whether expression have side effect
*/
bool HasSideEffect(const Expr& e);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
......@@ -79,7 +85,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
Expr body);
/*!
* \brief Flatten the multi-dimensional read/write
* to single dimensional Load/Store
......
......@@ -34,6 +34,17 @@ class IRVisitor {
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
// overloadable visit function.
virtual void Visit_(const Variable* op);
virtual void Visit_(const AttrStmt* op);
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
virtual void Visit_(const Free* op);
virtual void Visit_(const Call* op);
};
/*!
......
/*!
* Copyright (c) 2016 by Contributors
* \file module.h
* \brief Low level IR module,
* Contains lowered function information.
*/
#ifndef TVM_MODULE_H_
#define TVM_MODULE_H_
#include <tvm/container.h>
#include <ir/FunctionBase.h>
#include <string>
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
namespace tvm {
// Internal node container of lowered function.
class LoweredFuncNode;
// Internal node container of module.
class ModuleNode;
/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const LoweredFuncNode* operator->() const;
/*! \brief specify container node */
using ContainerType = LoweredFuncNode;
};
/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The arguments of the function
* This function can only take pod type(int, float) and void* as arguments.
*/
Array<Var> args;
/*!
* \brief The IterVar axis of threads
* Each axis need host function to specify a size.
* \note Calling convention into LoweredFunc
*
* Assume we have a LoweredFunc f, a call into f
* Call(f, arg1, arg2, ..., arg_n,
* size_axis_1, size_axis_2, ... size_axis_m)
*
* Here n = len(args), m = len(thread_axis)
*
* The CodeGen should take this and translate this call
* to corresponding API specific kernel launchs or function calls.
*/
Array<IterVar> thread_axis;
/*!
* \brief The hint data type of Var handles defined in LetStmt
* Can be used as hint when generating type signiture.
* The creation rule is given by
* handle_data_type[var_handle] = make_const(the_type, 0);
*
* \note Expr is used instead Type, because Type cannot be hold by Map.
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
// there is no return value, but return 1
// to enable Call into this function.
int num_outputs() const final {
return 1;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("body", &body);
}
static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_NODE_TYPE_INFO(LoweredFuncNode);
};
// Implementations of inline functions
inline const LoweredFuncNode* LoweredFunc::operator->() const {
return static_cast<const LoweredFuncNode*>(node_.get());
}
} // namespace tvm
#endif // TVM_MODULE_H_
......@@ -56,3 +56,9 @@ class IterVar(NodeBase, _expr.ExprOp):
class Buffer(NodeBase):
"""Represent a Buffer in TVM."""
pass
@register_node
class LoweredFunc(NodeBase):
"""Represent a LoweredFunc in TVM."""
pass
......@@ -7,6 +7,7 @@
#define TVM_BASE_COMMON_H_
#include <tvm/base.h>
#include <tvm/expr.h>
#include <string>
namespace tvm {
......@@ -30,7 +31,7 @@ inline Type String2Type(std::string s) {
} else if (s.substr(0, 5) == "float") {
code = Type::Float; s = s.substr(5);
} else if (s == "handle") {
return Type(Type::Handle, 32, 1);
return Handle();
} else {
LOG(FATAL) << "unknown type " << s;
}
......
......@@ -5,6 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/codegen.h>
#include "./c_api_registry.h"
#include "../codegen/codegen_c.h"
......@@ -17,9 +18,19 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_codegen_CompileToC)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = CodeGenC().Compile(
*ret = CodeGenC().Compile(args.at(0), args.at(1));
});
TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = MakeAPI(
args.at(0), args.at(1), args.at(2), args.at(3));
});
TVM_REGISTER_API(_codegen_SplitHostDevice)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = SplitHostDevice(args.at(0));
});
} // namespace codegen
} // namespace tvm
......@@ -8,6 +8,7 @@
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/module.h>
#include <string>
#include <unordered_map>
......@@ -23,16 +24,12 @@ class CodeGenC {
public:
/*!
* \brief Generate the C code of statement
* \param body The body of the function.
* \param fun_name The name of the function.
* \param args The arguments to the function.
* \param f The function to be compiled
* \param output_ssa Whether output ssa form.
* \note Only call compile once,
* create a new codegen object each time.
*/
std::string Compile(Stmt body,
std::string fun_name,
Array<Var> args,
std::string Compile(LoweredFunc f,
bool output_ssa);
/*!
* \brief Print the Stmt n to CodeGenC->stream
......@@ -49,7 +46,7 @@ class CodeGenC {
* \brief Same as PrintExpr, but simply returns result string
* \param n The expression to be printed.
*/
inline std::string PrintExpr(const Expr& n) {
std::string PrintExpr(const Expr& n) {
std::ostringstream os;
PrintExpr(n, os);
return os.str();
......@@ -85,7 +82,9 @@ class CodeGenC {
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintStmt(const ir::AssertStmt* op);
virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Call* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
......@@ -116,7 +115,13 @@ class CodeGenC {
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
bool BufferTypeMatch(const Variable* buf_var, Type t) const;
bool HandleTypeMatch(const Variable* buf_var, Type t) const;
/*!
* \brief Register the data type of buf_var
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
void HandleTypeRegister(const Variable* buf_var, Type t);
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
......@@ -128,7 +133,7 @@ class CodeGenC {
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> alloc_buf_type_;
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */
......
/*!
* Copyright (c) 2017 by Contributors
* \file make_api.cc Build API function.
*/
#include <tvm/codegen.h>
#include <tvm/ir.h>
#include <tvm/buffer.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include "../pass/ir_util.h"
namespace tvm {
namespace codegen {
using namespace ir;
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) {
return Call::make(
t, intrinsic::tvm_array_get_field,
{arr, IntImm::make(Int(32), kind)},
Call::PureIntrinsic);
}
inline Stmt AssertNull(Var handle, std::string msg) {
return AssertStmt::make(Call::make(
Bool(1), intrinsic::tvm_handle_is_null,
{handle}, Call::PureIntrinsic), msg);
}
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
return AssertStmt::make(lhs == rhs, msg);
}
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args) {
const Type tvm_index_type = UInt(32);
const Stmt nop = Evaluate::make(0);
// Data field definitions
// The packed fields
Var v_packed_args("args", Handle());
Var v_packed_arg_type_ids("arg_type_ids", Handle());
Var v_num_packed_args("num_args", Int(32));
// The arguments of the function.
Array<Var> args;
// seq_init gives sequence of initialization
// seq_check gives sequence of later checks after iniit
std::vector<Stmt> seq_init, seq_check;
std::unordered_set<const Variable*> visited;
// the handle data types
Map<Var, Expr> handle_data_type;
// ---------------------------
// local function defintiions
// load i-th argument as type t
auto f_arg_value = [&](Type t, int i) {
Array<Expr> call_args{
v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)};
return Call::make(
t, intrinsic::tvm_api_load_arg, call_args,
Call::PureIntrinsic);
};
// get declaration of argument i
auto f_arg_decl = [&](int i) {
std::ostringstream os;
os << "arg" << i;
const Variable* v = api_args[i].as<Variable>();
return Var(os.str(), v ? v->type: Handle());
};
// Push related into assertions or variable defintion
// given the symbolic declaration and concrete value
auto f_push = [&](Expr sym, Expr value, std::string field) {
if (sym.as<Variable>()) {
// If sym is a Variable and this Variable is not yet defined
// add this to defintion.
Var v(sym.node_);
if (!visited.count(v.get())) {
seq_init.emplace_back(LetStmt::make(v, value, nop));
visited.insert(v.get());
return true;
}
}
// otherwise, assume sym is already defined, insert assertion.
std::ostringstream os;
os << "Field " << field << " has a unsatisfied constraint";
seq_check.emplace_back(MakeAssertEQ(sym, value, os.str()));
return false;
};
// ---------------------------
// start of logics
// add signiture for packed arguments.
if (num_packed_args != 0) {
args.push_back(v_packed_args);
args.push_back(v_packed_arg_type_ids);
args.push_back(v_num_packed_args);
std::ostringstream os;
os << "expected num_args to be " << num_packed_args;
seq_init.emplace_back(
MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}
for (size_t i = 0; i < api_args.size(); ++i) {
Var v_arg = f_arg_decl(i);
if (i < static_cast<size_t>(num_packed_args)) {
seq_init.emplace_back(LetStmt::make(
v_arg, f_arg_value(v_arg.type(), i), nop));
} else {
args.push_back(v_arg);
}
// add checks for functions.
if (api_args[i].as<Variable>()) {
f_push(Var(api_args[i].node_), v_arg, v_arg->name_hint);
} else {
// Buffer checks
CHECK(api_args[i].as<BufferNode>())
<< "api_args can only be Buffer or Var";
Buffer buf(api_args[i].node_);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_index_type, v_arg, intrinsic::kNDim);
std::ostringstream ndim_err_msg;
ndim_err_msg << "arg_" << i
<< ".ndim is expected to equal "
<< buf->shape.size();
seq_init.emplace_back(
MakeAssertEQ(v_ndim, UIntImm::make(tvm_index_type, buf->shape.size()),
ndim_err_msg.str()));
// type checks
Type dtype = buf->dtype;
std::ostringstream type_err_msg;
type_err_msg << "arg" << i << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) ==
UIntImm::make(UInt(8), dtype.code()) &&
TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) ==
UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes()));
seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// Data Field
if (f_push(buf->ptr, TVMArrayGet(Handle(), v_arg, intrinsic::kData),
v_arg->name_hint + ".data")) {
Var vptr(buf->ptr);
handle_data_type.Set(vptr, make_const(buf->dtype, 0));
}
// shape field
Var v_shape(v_arg->name_hint + ".shape", Handle());
handle_data_type.Set(v_shape, UIntImm::make(tvm_index_type, 0));
seq_init.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop));
for (size_t k = 0; k < buf->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
f_push(buf->shape[k],
cast(buf->shape[k].type(),
Load::make(tvm_index_type, v_shape, IntImm::make(Int(32), k))),
field_name.str());
}
// strides field
Var v_strides(v_arg->name_hint + ".strides", Handle());
handle_data_type.Set(v_strides, UIntImm::make(tvm_index_type, 0));
seq_init.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop));
if (buf->strides.size() == 0) {
std::ostringstream stride_err_msg;
stride_err_msg << "arg_" << i << ".strides:"
<< " expected to be nullptr for contiguous array";
seq_init.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
} else {
for (size_t k = 0; k < buf->strides.size(); ++k) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
f_push(buf->strides[k],
cast(buf->shape[k].type(),
Load::make(tvm_index_type, v_strides, IntImm::make(Int(32), k))),
field_name.str());
}
}
}
}
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
n->name = name;
n->args = args;
n->handle_data_type = handle_data_type;
n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f);
if (undefined.size() != 0) {
std::ostringstream os;
for (Var v : undefined) {
os << " \'" << v->name_hint << "\' ";
}
os << " does not appeared in api_args";
LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
}
return f;
}
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/codegen.h>
#include <tvm/ir.h>
#include <tvm/module.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_map>
namespace tvm {
namespace codegen {
using namespace ir;
// use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public IRMutator {
public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == "thread_extent") {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times
// use the first appearance as def.
if (!use_count_.count(iv->var.get())) {
this->HandleDef(iv->var.get());
thread_axis_.push_back(iv);
thread_extent_.push_back(op->value);
}
Expr value = op->value;
if (visit_thread_extent_) {
value = this->Mutate(value);
}
Stmt body = this->Mutate(op->body);
if (value.same_as(value) && body.same_as(body)) return s;
return AttrStmt::make(op->node, op->type_key, value, body);
} else {
return IRMutator::Mutate_(op, s);
}
}
Stmt Mutate_(const LetStmt *op, const Stmt& s) final {
this->HandleDef(op->var.get());
Stmt body = this->Mutate(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) {
return body;
} else {
Expr value = this->Mutate(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
return s;
} else {
return LetStmt::make(op->var, value, body);
}
}
}
Stmt Mutate_(const For *op, const Stmt& s) final {
this->HandleDef(op->loop_var.get());
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Allocate *op, const Stmt& s) final {
this->HandleDef(op->buffer_var.get());
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Store *op, const Stmt& s) final {
this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, s);
}
Expr Mutate_(const Let *op, const Expr& e) final {
this->HandleDef(op->var.get());
Expr body = this->Mutate(op->body);
// eliminate unreferenced let
if (use_count_.at(op->var.get()) == 0 &&
!HasSideEffect(op->value)) {
return body;
} else {
Expr value = this->Mutate(op->value);
if (body.same_as(op->body) &&
value.same_as(op->value)) {
return e;
} else {
return Let::make(op->var, value, body);
}
}
}
Expr Mutate_(const Variable *op, const Expr& e) final {
this->HandleUse(e);
return IRMutator::Mutate_(op, e);
}
Expr Mutate_(const Load *op, const Expr& e) final {
this->HandleUse(op->buffer_var);
return IRMutator::Mutate_(op, e);
}
void HandleDef(const Variable* v) {
CHECK(!use_count_.count(v))
<< "variable is already defined";
use_count_[v] = 0;
}
void HandleUse(const Expr& v) {
CHECK(v.as<Variable>());
Var var(v.node_);
auto it = use_count_.find(var.get());
if (it != use_count_.end()) {
if (it->second >= 0) {
++it->second;
}
} else {
undefined_.push_back(var);
use_count_[var.get()] = -1;
}
}
// The fields are publically readible to
// be accessible to the users.
bool visit_thread_extent_{true};
Array<Var> undefined_;
Array<IterVar> thread_axis_;
Array<Expr> thread_extent_;
std::unordered_map<const Variable*, int> use_count_;
};
class HostDeviceSplitter : public IRMutator {
public:
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->type_key == "thread_extent") {
LOG(INFO) << "??";
IterVar iv(op->node.node_);
return SplitDeviceFunc(s);
}
return IRMutator::Mutate_(op, s);
}
Array<LoweredFunc> Split(LoweredFunc f) {
for (auto kv : f->handle_data_type) {
handle_data_type_[kv.first.get()] = kv.second;
}
name_ = f->name;
std::shared_ptr<LoweredFuncNode> n =
std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body);
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
ret.push_back(x);
}
return ret;
}
private:
Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os;
os << name_ << "_kernel" << device_funcs_.size();
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
// isolate the device function.
IRUseDefAnalysis m;
m.visit_thread_extent_ = false;
n->body = m.Mutate(body);
n->name = os.str();
n->args = m.undefined_;
CHECK_NE(m.thread_extent_.size(), 0U);
// improve the handle data type
for (Var arg : n->args) {
auto it = handle_data_type_.find(arg.get());
if (it != handle_data_type_.end()) {
n->handle_data_type.Set(arg, it->second);
}
}
LoweredFunc f_device(n);
Array<Expr> call_args;
for (Var arg : n->args) {
call_args.push_back(arg);
}
for (Expr ext : m.thread_extent_) {
call_args.push_back(ext);
}
device_funcs_.emplace_back(f_device);
return Evaluate::make(Call::make(
Int(32), f_device->name, call_args, Call::Extern, f_device));
}
// function name
std::string name_;
// the device functions
std::vector<LoweredFunc> device_funcs_;
std::unordered_map<const Variable*, Expr> handle_data_type_;
};
Array<Var> UndefinedVars(const LoweredFunc& f) {
IRUseDefAnalysis m;
for (Var arg : f->args) {
m.use_count_[arg.get()] = 0;
}
m.Mutate(f->body);
return m.undefined_;
}
Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
return HostDeviceSplitter().Split(func);
}
} // namespace codegen
} // namespace tvm
......@@ -17,36 +17,28 @@ class IRInline : public IRMutator {
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
CHECK_EQ(call->value_index, 0);
return InlineCall(call);
} else {
return expr;
}
}
Stmt Mutate(Stmt stmt) final {
return IRMutator::Mutate(stmt);
}
Expr Mutate_(const Call* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Call>();
private:
FunctionRef f_;
Array<Var> args_;
Expr body_;
Expr InlineCall(const Call* op) {
if (op->func == f_) {
CHECK_EQ(op->value_index, 0);
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
return expr;
} else {
return e;
}
}
private:
FunctionRef f_;
Array<Var> args_;
Expr body_;
};
Stmt Inline(Stmt stmt,
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_util.h
* \brief Helper functions to construct and compose IR nodes.
*/
#ifndef TVM_PASS_IR_UTIL_H_
#define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h>
#include <vector>
namespace tvm {
namespace ir {
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
inline Stmt MergeNest(std::vector<Stmt> nest, Stmt body) {
// use reverse iteration
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri;
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else if (s.as<AssertStmt>()) {
body = Block::make(s, body);
} else {
LOG(FATAL) << "not supported nest type";
}
}
return body;
}
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
inline Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
body = MergeNest(*ri, body);
}
return body;
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
......@@ -8,7 +8,6 @@
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
public:
......@@ -26,7 +25,6 @@ class IRApplyVisit : public IRVisitor {
std::unordered_set<const Node*> visited_;
};
} // namespace
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
......@@ -36,12 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() { // NOLINT(*)
static FVisit inst; return inst;
}
// namespace to register the functors.
namespace {
using namespace Halide::Internal;
void NoOp(const NodeRef& n, IRVisitor* v) {
}
......@@ -59,24 +51,82 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
}
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
v->Visit(op->source);
});
#define DISPATCH_TO_VISIT(OP) \
set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
v->Visit_(op); \
})
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt* op, IRVisitor* v) {
v->Visit(op->value);
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
.DISPATCH_TO_VISIT(Store)
.DISPATCH_TO_VISIT(Let)
.DISPATCH_TO_VISIT(Call)
.DISPATCH_TO_VISIT(Free);
void IRVisitor::Visit_(const Variable* op) {}
void IRVisitor::Visit_(const LetStmt *op) {
this->Visit(op->value);
this->Visit(op->body);
}
void IRVisitor::Visit_(const AttrStmt* op) {
this->Visit(op->value);
this->Visit(op->body);
}
void IRVisitor::Visit_(const For *op) {
IRVisitor* v = this;
v->Visit(op->min);
v->Visit(op->extent);
v->Visit(op->body);
});
}
void IRVisitor::Visit_(const Allocate *op) {
IRVisitor* v = this;
for (size_t i = 0; i < op->extents.size(); i++) {
v->Visit(op->extents[i]);
}
v->Visit(op->body);
v->Visit(op->condition);
if (op->new_expr.defined()) {
v->Visit(op->new_expr);
}
}
void IRVisitor::Visit_(const Load *op) {
this->Visit(op->index);
}
void IRVisitor::Visit_(const Store *op) {
this->Visit(op->value);
this->Visit(op->index);
}
void IRVisitor::Visit_(const Let *op) {
this->Visit(op->value);
this->Visit(op->body);
}
void IRVisitor::Visit_(const Free* op) {}
void IRVisitor::Visit_(const Call *op) {
VisitArray(op->args, this);
}
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
VisitRDom(op->rdom, v);
v->Visit(op->source);
})
.set_dispatch<IntImm>(NoOp)
.set_dispatch<UIntImm>(NoOp)
.set_dispatch<FloatImm>(NoOp)
.set_dispatch<StringImm>(NoOp)
.set_dispatch<Variable>(NoOp);
.set_dispatch<StringImm>(NoOp);
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
......@@ -116,29 +166,15 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
v->Visit(op->true_value);
v->Visit(op->false_value);
})
.set_dispatch<Load>([](const Load *op, IRVisitor* v) {
v->Visit(op->index);
})
.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
v->Visit(op->base);
v->Visit(op->stride);
})
.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
v->Visit(op->value);
})
.set_dispatch<Call>([](const Call *op, IRVisitor* v) {
VisitArray(op->args, v);
})
.set_dispatch<Let>([](const Let *op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
});
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<LetStmt>([](const LetStmt *op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->body);
})
.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
v->Visit(op->condition);
v->Visit(op->message);
......@@ -146,30 +182,10 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
v->Visit(op->body);
})
.set_dispatch<For>([](const For *op, IRVisitor* v) {
v->Visit(op->min);
v->Visit(op->extent);
v->Visit(op->body);
})
.set_dispatch<Store>([](const Store *op, IRVisitor* v) {
v->Visit(op->value);
v->Visit(op->index);
})
.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
VisitArray(op->args, v);
v->Visit(op->value);
})
.set_dispatch<Allocate>([](const Allocate *op, IRVisitor* v) {
for (size_t i = 0; i < op->extents.size(); i++) {
v->Visit(op->extents[i]);
}
v->Visit(op->body);
v->Visit(op->condition);
if (op->new_expr.defined()) {
v->Visit(op->new_expr);
}
})
.set_dispatch<Free>(NoOp)
.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
// Mutate the bounds
for (size_t i = 0; i < op->bounds.size(); i++) {
......@@ -193,6 +209,5 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
v->Visit(op->value);
});
} // namespace
} // namespace ir
} // namespace tvm
......@@ -9,6 +9,7 @@
#include <tvm/schedule_pass.h>
#include "./scope.h"
#include "./ir_util.h"
#include "../schedule/graph.h"
namespace tvm {
......@@ -32,18 +33,27 @@ void PassUpOffset(const Stage& s,
Expr outer = state.at(s->outer);
Expr inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr offset = inner + outer * factor;
Expr outer_min = dom_map.at(s->parent)->min;
if (!is_zero(outer_min)) {
offset = outer_min + offset;
Expr parent_min = dom_map.at(s->parent)->min;
state[s->parent] = inner + outer * factor;
// add min if they exist
if (!is_zero(parent_min)) {
state[s->parent] = parent_min + state[s->parent];
}
state[s->parent] = offset;
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
Expr value = state.at(s->fused);
Expr factor = dom_map.at(s->inner)->extent;
Expr outer_min = dom_map.at(s->outer)->min;
Expr inner_min = dom_map.at(s->inner)->min;
state[s->outer] = value / factor;
state[s->inner] = value % factor;
// add min if they exist
if (!is_zero(outer_min)) {
state[s->outer] = outer_min + state[s->outer];
}
if (!is_zero(inner_min)) {
state[s->inner] = outer_min + state[s->inner];
}
} else {
LOG(FATAL) << "unknown relation type";
}
......@@ -82,45 +92,6 @@ void SplitByAdd(Expr expr,
}
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
*/
Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
// use reverse iteration
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
for (auto rj = ri->rbegin(); rj != ri->rend(); ++rj) {
Stmt s = *rj;
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
}
}
return body;
}
/*!
* \brief Make the loop nest of the correspondings schedule.
* \param sch The schedule.
* \param dom_map The domain map.
......@@ -142,16 +113,32 @@ std::vector<std::vector<Stmt> > MakeLoopNest(
for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
Range dom = dom_map.at(iv);
// initialize the offset and loop_level
offset[iv] = iv->var;
loop_level[iv->var.as<Variable>()] = i + 1;
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
Range dom = dom_map.at(iv);
if (is_zero(dom->min)) {
nest[i + 1].emplace_back(
For::make(iv->var, dom->min, dom->extent,
For::make(iv->var, 0, dom->extent,
ForType::Serial, DeviceAPI::None, no_op));
} else {
Var idx(iv->var->name_hint + ".idx", iv->var.type());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
ForType::Serial, DeviceAPI::None, no_op));
nest[i + 1].emplace_back(
LetStmt::make(iv->var, dom->min + idx, no_op));
}
} else {
// Always restrict threaded IterVar to starts from 0.
CHECK(is_zero(dom->min));
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "thread_extent", dom->extent, no_op));
}
// annotate the extent of the IterVar
nest[i + 1].emplace_back(
AttrStmt::make(iv, "scope", iv->var, no_op));
}
......
/*!
* Copyright (c) 2016 by Contributors
* \file simple_passes.cc
* \brief Implementation of simple passes
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
class IRSideEffect : public IRVisitor {
public:
void Visit(const NodeRef& e) final {
if (has_side_effect_) return;
}
void Visit_(const Call* op) final {
if (!op->is_pure()) {
has_side_effect_ = true; return;
} else {
IRVisitor::Visit_(op);
}
}
bool has_side_effect_{false};
};
bool HasSideEffect(const Expr& e) {
IRSideEffect v;
v.Visit(e);
return v.has_side_effect_;
}
} // namespace ir
} // namespace tvm
......@@ -24,31 +24,15 @@ def mock_test_add():
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
print(stmt)
output_ssa = False
code = tvm.codegen.CompileToC(stmt, "myadd",
[Ab.ptr, Bb.ptr, Cb.ptr, n],
output_ssa)
f = tvm.codegen.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 1)
f_list = tvm.codegen.SplitHostDevice(f)
for x in f_list:
code = tvm.codegen.CompileToC(x, output_ssa)
print(code)
def codegen():
# generate host/device code
host_code, device_code = tvm.codegen.GenCUDA(
s,
inputs={A: Ab, B:Bb},
outputs={C: Cb},
args=[A, B, C])
# generate a function based on the code
f = tvm.cuda.build_function(host_code, device_code)
# create arrays
a = tvm.nd.array(np.ones(10), ctx=tvm.gpu(0))
b = tvm.nd.array(np.ones(10), ctx=tvm.gpu(0))
c = tvm.nd.array(np.zeros(10), ctx=tvm.gpu(0))
# calll the generated code
f(a, b, c)
# sync the result
np.testing.assert_equal(c.asnumpy(), np.ones(10) * 2)
if __name__ == "__main__":
mock_test_add()
import tvm
import numpy
def test_makeapi():
"""Not yet working, mock design"""
n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.ir_pass.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
num_packed_args = 2
f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
assert(f.handle_data_type[Ab.ptr].dtype == Ab.dtype)
assert(len(f.args) == 5)
output_ssa = False
if __name__ == "__main__":
test_makeapi()
......@@ -18,6 +18,7 @@ def test_flatten2():
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
stmt = tvm.ir_pass.Simplify(stmt)
print(stmt)
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment