Commit d89917b6 by Tianqi Chen Committed by GitHub

[PASS] StorageFlatten and StorageSync, safe condition in schedul_ops, gemm example. (#31)

parent a2c8a29b
......@@ -22,54 +22,6 @@ using runtime::TVMArgs;
using runtime::TVMRetValue;
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);
/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);
/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
/*!
* \brief Build a stack VM function.
* \param func The LoweredFunc to be build
* \param device_funcs The additional device functions
......
......@@ -88,6 +88,25 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
* }
*/
constexpr const char* tvm_call_global = "tvm_call_global";
/*!
* \brief See pesudo code
*
* int tvm_call_device(name, TVMValue* args) {
* PackedFunc df = CodeGenEnv->GetDevice(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_device = "tvm_call_device";
/*!
* \brief See pesudo code
*
* int tvm_storage_sync(std::string storage_scope) {
* __sync(storage_scope);
* return 0;
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*! \brief The field id of each field in array */
enum TVMArrayFieldKind {
......
......@@ -14,9 +14,11 @@
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
#include <string>
#include "./expr.h"
#include "./buffer.h"
#include "./schedule.h"
#include "./lowered_func.h"
namespace tvm {
namespace ir {
......@@ -95,6 +97,62 @@ Stmt Inline(Stmt stmt,
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
* - Map the values in the api_args to of Var that is required by body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n)
*
* where n == len(api_args), k == num_packed_args
*
* There is no thread_axis in generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);
/*!
* \brief Count number of undefined vars in f.
* \param f The function to be checked.
* \return Number of undefined vars.
*/
Array<Var> UndefinedVars(const LoweredFunc& f);
/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param stmt The stmt to be trasnformed.
* \param storage_scope The storage scope considered.
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
} // namespace ir
} // namespace tvm
......
......@@ -40,6 +40,7 @@ class IRVisitor {
virtual void Visit_(const LetStmt* op);
virtual void Visit_(const For* op);
virtual void Visit_(const Allocate* op);
virtual void Visit_(const IfThenElse* op);
virtual void Visit_(const Load* op);
virtual void Visit_(const Store* op);
virtual void Visit_(const Let* op);
......
......@@ -65,17 +65,19 @@ def build(sch,
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.Simplify(stmt)
print(stmt)
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = codegen.SplitHostDevice(fapi)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
fsplits = ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")
if record_codes is not None:
output_ssa = False
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))
for c in record_codes:
print(c)
if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
......
......@@ -31,18 +31,6 @@ TVM_REGISTER_API(_codegen_CompileToC)
}
});
TVM_REGISTER_API(_codegen_MakeAPI)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = MakeAPI(
args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_API(_codegen_SplitHostDevice)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = SplitHostDevice(args[0]);
});
TVM_REGISTER_API(_codegen_BuildStackVM)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildStackVM(args[0],
......
......@@ -52,6 +52,9 @@ REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
} // namespace ir
} // namespace tvm
......@@ -20,7 +20,6 @@ std::string CodeGenC::Compile(LoweredFunc f,
HandleTypeRegister(kv.first.get(), kv.second.type());
}
this->indent += 2;
this->stream << "void " << f->name << "(";
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
......@@ -38,8 +37,9 @@ std::string CodeGenC::Compile(LoweredFunc f,
stream << ' ' << vid;
}
stream << ") {\n";
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->indent -= 2;
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n";
return stream.str();
......@@ -54,19 +54,23 @@ std::string CodeGenC::SSAGetID(std::string src, Type t) {
if (name_alloc_map_.count(src)) return src;
auto it = ssa_assign_map_.find(src);
if (it != ssa_assign_map_.end()) {
return it->second;
} else {
if (scope_mark_.at(it->second.scope_id)) {
return it->second.vid;
}
}
this->PrintIndent();
std::string id = GetUniqueName("_");
ssa_assign_map_[src] = id;
SSAEntry e;
e.vid = GetUniqueName("_");
e.scope_id = static_cast<int>(scope_mark_.size() - 1);
ssa_assign_map_[src] = e;
if (src.length() > 3 &&
src[0] == '(' && src[src.length() - 1] == ')') {
src = src.substr(1, src.length() - 2);
}
PrintType(t, stream);
stream << ' ' << id << " = " << src << ";\n";
return id;
}
stream << ' ' << e.vid << " = " << src << ";\n";
return e.vid;
}
void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*)
......@@ -142,9 +146,12 @@ void CodeGenC::MarkConst(std::string vid) {
if (print_ssa_form_) {
auto it = ssa_assign_map_.find(vid);
if (it == ssa_assign_map_.end()) {
ssa_assign_map_[vid] = vid;
SSAEntry e;
e.vid = vid;
e.scope_id = 0;
ssa_assign_map_[vid] = e;
} else {
CHECK_EQ(it->second, vid);
CHECK_EQ(it->second.vid, vid);
}
}
}
......@@ -242,6 +249,9 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
})
.set_dispatch<FloatImm>([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
PrintConst(op, os, p);
})
.set_dispatch<StringImm>([](const StringImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*)
os << "\"" << op->value << "\"";
});
template<typename T>
......@@ -340,49 +350,22 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, CodeGenC* p) {
p->PrintStmt(op->body);
})
.set_dispatch<For>([](const For *op, CodeGenC* p) {
std::string extent = p->PrintExpr(op->extent);
p->PrintIndent();
std::string vid = p->AllocVarID(op->loop_var.get());
CHECK(is_zero(op->min));
p->stream << "for (";
p->PrintType(op->loop_var.type(), p->stream);
p->stream << ' ' << vid << " = 0; "
<< vid << " < " << extent
<< "; ++" << vid << ") {\n";
p->indent += 2;
p->PrintStmt(op->body);
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
})
.set_dispatch<Block>([](const Block *op, CodeGenC* p) {
p->PrintStmt(op->first);
if (op->rest.defined()) p->PrintStmt(op->rest);
})
.set_dispatch<Evaluate>([](const Evaluate *op, CodeGenC* p) {
if (is_const(op->value)) return;
const Call* call = op->value.as<Call>();
if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) {
p->PrintStorageSync(call->args[0].as<StringImm>()->value);
} else {
std::string vid = p->PrintExpr(op->value);
p->PrintIndent();
p->stream << "(void)" << vid << ";\n";
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenC* p) {
std::string cond = p->PrintExpr(op->condition);
p->PrintIndent();
p->stream << "if (" << cond << ") {\n";
p->indent += 2;
p->PrintStmt(op->then_case);
p->indent -= 2;
if (op->else_case.defined()) {
p->PrintIndent();
p->stream << "} else {\n";
p->indent += 2;
p->PrintStmt(op->else_case);
p->indent -= 2;
}
p->PrintIndent();
p->stream << "}\n";
});
});
#define DISPATCH_EXPR(OP) \
......@@ -517,13 +500,22 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt)
.set_dispatch<Store>([](const Store *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<Allocate>([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AttrStmt>([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); });
.set_dispatch<AssertStmt>([](const AssertStmt *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<For>([](const For *op, CodeGenC* p) { p->PrintStmt(op); })
.set_dispatch<IfThenElse>([](const IfThenElse *op, CodeGenC* p) { p->PrintStmt(op); });
void CodeGenC::PrintThreadTagExpr(
std::string thread_tag, std::ostream& os) const { // NOLINT(*)
void CodeGenC::PrintThreadIndexExpr(
std::string thread_tag, std::ostream& os) { // NOLINT(*)
os << thread_tag;
}
void CodeGenC::PrintStorageSync(const std::string& sync) { // NOLINT(*)
}
void CodeGenC::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_EQ(scope, "global");
}
void CodeGenC::PrintStmt(const LetStmt* op) {
std::string value = PrintExpr(op->value);
if (print_ssa_form_) {
......@@ -581,9 +573,12 @@ void CodeGenC::PrintStmt(const Allocate* op) {
int32_t constant_size = op->constant_allocation_size();
CHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
const Variable* buffer = op->buffer_var.as<Variable>();
std::string scope = alloc_storage_scope_.at(buffer);
PrintStorageScope(scope, stream);
PrintType(op->type, stream);
stream << ' '<< vid << '['
<< constant_size << "]\n;";
<< constant_size << "];\n";
}
HandleTypeRegister(op->buffer_var.get(), op->type);
this->PrintStmt(op->body);
......@@ -599,10 +594,14 @@ void CodeGenC::PrintStmt(const AttrStmt* op) {
stream << ' '
<< AllocVarID(iv->var.get())
<< " = ";
PrintThreadTagExpr(iv->thread_tag, stream);
PrintThreadIndexExpr(iv->thread_tag, stream);
stream << ";\n";
}
}
} else if (op->type_key == "storage_scope") {
const Variable* v = op->node.as<Variable>();
CHECK(v);
alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
}
this->PrintStmt(op->body);
}
......@@ -619,5 +618,54 @@ void CodeGenC::PrintStmt(const AssertStmt* op) {
}
}
int CodeGenC::BeginScope() {
int sid = static_cast<int>(scope_mark_.size());
scope_mark_.push_back(true);
indent += 2;
return sid;
}
void CodeGenC::EndScope(int scope_id) {
scope_mark_[scope_id] = false;
indent -= 2;
}
void CodeGenC::PrintStmt(const For* op) {
std::string extent = PrintExpr(op->extent);
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
CHECK(is_zero(op->min));
stream << "for (";
PrintType(op->loop_var.type(), stream);
stream << ' ' << vid << " = 0; "
<< vid << " < " << extent
<< "; ++" << vid << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";
}
void CodeGenC::PrintStmt(const IfThenElse* op) {
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (" << cond << ") {\n";
int then_scope = BeginScope();
PrintStmt(op->then_case);
this->EndScope(then_scope);
if (op->else_case.defined()) {
PrintIndent();
stream << "} else {\n";
int else_scope = BeginScope();
PrintStmt(op->else_case);
this->EndScope(else_scope);
}
PrintIndent();
stream << "}\n";
}
} // namespace codegen
} // namespace tvm
......@@ -10,6 +10,7 @@
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <string>
#include <vector>
#include <unordered_map>
namespace tvm {
......@@ -80,13 +81,18 @@ class CodeGenC {
virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*)
/*!
* \brief Print expr representing the thread tag
* \param thread_tag The tag in the thread.
* \param tag The tag in the thread.
* \param os The strean to output to
*/
virtual void PrintThreadTagExpr(
std::string thread_tag, std::ostream& os) const; // NOLINT(*)
virtual void PrintThreadIndexExpr(
std::string tag, std::ostream& os); // NOLINT(*)
virtual void PrintStorageScope(const std::string& scope, std::ostream& os); // NOLINT(*
virtual void PrintStorageSync(const std::string& scope); // NOLINT(*)
virtual void PrintStmt(const ir::LetStmt* op);
virtual void PrintStmt(const ir::Store* op);
virtual void PrintStmt(const ir::For* op);
virtual void PrintStmt(const ir::IfThenElse* op);
virtual void PrintStmt(const ir::Allocate* op);
virtual void PrintStmt(const ir::AttrStmt* op);
virtual void PrintStmt(const ir::AssertStmt* op);
......@@ -114,6 +120,13 @@ class CodeGenC {
std::string arg_addr_space_;
private:
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
......@@ -122,6 +135,16 @@ class CodeGenC {
*/
std::string SSAGetID(std::string src, Type t);
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
*/
int BeginScope();
/*!
* \brief mark the end of an old scope.
* \param scope_id The scope id to be ended.
*/
void EndScope(int scope_id);
/*!
* \brief If buffer is allocated as type t.
* \param buf_var The buffer variable.
* \param t The type to be checked.
......@@ -145,10 +168,14 @@ class CodeGenC {
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */
std::unordered_map<std::string, std::string> ssa_assign_map_;
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
};
} // namespace codegen
......
......@@ -22,6 +22,23 @@ std::string CodeGenCUDA::Compile(
return CodeGenC::Compile(f, output_ssa);
}
void CodeGenCUDA::PrintStorageSync(const std::string& sync) {
if (sync == "shared") {
this->PrintIndent();
this->stream << "__syncthreads();\n";
} else if (sync == "global") {
LOG(FATAL) << "not supported";
}
}
void CodeGenCUDA::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global");
if (scope == "shared") {
os << "__shared__ ";
}
}
#if TVM_CUDA_RUNTIME
std::unordered_map<LoweredFunc, PackedFunc>
MakeNVRTC(Array<LoweredFunc> funcs) {
......@@ -56,7 +73,6 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
PackedFunc BuildNVRTC(Array<LoweredFunc> fsplits, std::string host_mode) {
Array<LoweredFunc> device_list(fsplits.begin() + 1, fsplits.end());
std::unordered_map<LoweredFunc, PackedFunc> device_funcs = MakeNVRTC(device_list);
if (host_mode == "stackvm") {
StackVM vm = codegen::CodeGenStackVM().Compile(fsplits[0], device_funcs);
auto f = [vm](TVMArgs args, TVMRetValue* rv) {
......
......@@ -25,6 +25,9 @@ class CodeGenCUDA : public CodeGenC {
*/
std::string Compile(LoweredFunc f,
bool output_ssa);
// override behavior
void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
};
} // namespace codegen
......
......@@ -10,6 +10,7 @@
#include "./codegen_stack_vm.h"
#include "../runtime/opencl/opencl_common.h"
#include "../runtime/opencl/opencl_module.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace codegen {
......@@ -22,22 +23,31 @@ std::string CodeGenOpenCL::Compile(
return CodeGenC::Compile(f, output_ssa);
}
void CodeGenOpenCL::PrintThreadTagExpr(
std::string thread_tag, std::ostream& os) const { // NOLINT(*)
if (thread_tag == "threadIdx.x") {
os << "get_local_id(0)";
} else if (thread_tag == "threadIdx.y") {
os << "get_local_id(1)";
} else if (thread_tag == "threadIdx.z") {
os << "get_local_id(2)";
} else if (thread_tag == "blockIdx.x") {
os << "get_global_id(0) / get_local_size(0)";
} else if (thread_tag == "blockIdx.y") {
os << "get_global_id(1) / get_local_size(1)";
} else if (thread_tag == "blockIdx.z") {
os << "get_global_id(2) / get_local_size(2)";
void CodeGenOpenCL::PrintThreadIndexExpr(
std::string tag, std::ostream& os) { // NOLINT(*)
runtime::ThreadScope ts = runtime::ThreadScope::make(tag);
if (ts.rank == 1) {
os << "get_local_id(" << ts.dim_index << ")";
} else {
LOG(FATAL) << "unknown thread tag";
os << "get_global_id(" << ts.dim_index << ")"
<< " / get_local_size(" << ts.dim_index << ")";
}
}
void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
if (sync == "shared") {
this->PrintIndent();
this->stream << "barrier(CLK_LOCAL_MEM_FENCE);\n";
} else if (sync == "global") {
LOG(FATAL) << "not supported";
}
}
void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global");
if (scope == "shared") {
os << "__local ";
}
}
......
......@@ -26,8 +26,10 @@ class CodeGenOpenCL : public CodeGenC {
std::string Compile(LoweredFunc f,
bool output_ssa);
// override print thread tag.
void PrintThreadTagExpr(
std::string thread_tag, std::ostream& os) const final; // NOLINT(*)
void PrintThreadIndexExpr(
std::string tag, std::ostream& os) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const std::string& scope) final; // NOLINT(*)
};
} // namespace codegen
......
......@@ -37,7 +37,7 @@ StackVM CodeGenStackVM::Compile(
for (const auto& kv : device_funcs) {
int fid = static_cast<int>(vm_.packed_func.size());
vm_.packed_func.push_back(kv.second);
device_fun_idmap_[kv.first] = fid;
device_fun_idmap_[kv.first->name] = fid;
}
this->Push(f->body);
return std::move(vm_);
......@@ -228,20 +228,19 @@ void CodeGenStackVM::Push_(const ir::Call* op) {
this->Push(op->args[0]);
this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_I64);
} else if (op->call_type == Call::Extern && op->func.defined()) {
CHECK(op->func->is_type<LoweredFuncNode>());
LoweredFunc f(op->func.node_);
auto it = device_fun_idmap_.find(f);
} else if (op->is_intrinsic(intrinsic::tvm_call_device)) {
std::string func_name = op->args[0].as<StringImm>()->value;
auto it = device_fun_idmap_.find(func_name);
CHECK(it != device_fun_idmap_.end())
<< "Cannot find device function " << f->name;
<< "Cannot find device function " << func_name;
const int fid = it->second;
std::vector<int> arg_type_codes(op->args.size());
for (size_t i = 0; i < op->args.size(); ++i) {
std::vector<int> arg_type_codes;
for (size_t i = 1; i < op->args.size(); ++i) {
this->Push(op->args[i]);
Type t = op->args[i].type();
int lanes = t.lanes();
CHECK_EQ(lanes, 1);
arg_type_codes[i] = t.code();
arg_type_codes.push_back(t.code());
}
this->PushCallPacked(fid, arg_type_codes);
} else {
......
......@@ -110,7 +110,7 @@ class CodeGenStackVM {
/*! \brief id of each global function */
std::unordered_map<std::string, int> global_fun_idmap_;
/*! \brief id of device function */
std::unordered_map<LoweredFunc, int> device_fun_idmap_;
std::unordered_map<std::string, int> device_fun_idmap_;
};
} // namespace codegen
......
......@@ -75,6 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Free);
Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
......
......@@ -45,6 +45,11 @@ inline Stmt MergeNest(std::vector<Stmt> nest, Stmt body) {
body = Stmt(n);
} else if (s.as<AssertStmt>()) {
body = Block::make(s, body);
} else if (s.as<Allocate>()) {
auto n = std::make_shared<Allocate>(*s.as<Allocate>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
......
......@@ -59,6 +59,8 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
.DISPATCH_TO_VISIT(Variable)
.DISPATCH_TO_VISIT(LetStmt)
.DISPATCH_TO_VISIT(AttrStmt)
.DISPATCH_TO_VISIT(IfThenElse)
.DISPATCH_TO_VISIT(For)
.DISPATCH_TO_VISIT(Allocate)
.DISPATCH_TO_VISIT(Load)
......@@ -107,6 +109,14 @@ void IRVisitor::Visit_(const Store *op) {
this->Visit(op->index);
}
void IRVisitor::Visit_(const IfThenElse *op) {
this->Visit(op->condition);
this->Visit(op->then_case);
if (op->else_case.defined()) {
this->Visit(op->else_case);
}
}
void IRVisitor::Visit_(const Let *op) {
this->Visit(op->value);
this->Visit(op->body);
......@@ -200,11 +210,6 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
v->Visit(op->first);
v->Visit(op->rest);
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, IRVisitor* v) {
v->Visit(op->condition);
v->Visit(op->then_case);
v->Visit(op->else_case);
})
.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) {
v->Visit(op->value);
});
......
......@@ -2,7 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file make_api.cc Build API function.
*/
#include <tvm/codegen.h>
#include <tvm/ir_pass.h>
#include <tvm/ir.h>
#include <tvm/buffer.h>
......@@ -10,11 +10,10 @@
#include <utility>
#include <unordered_set>
#include "../pass/ir_util.h"
#include "./ir_util.h"
namespace tvm {
namespace codegen {
using namespace ir;
namespace ir {
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) {
return Call::make(
......@@ -196,5 +195,5 @@ LoweredFunc MakeAPI(Stmt body,
}
return f;
}
} // namespace codegen
} // namespace ir
} // namespace tvm
......@@ -3,7 +3,6 @@
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/codegen.h>
#include <tvm/ir.h>
#include <tvm/lowered_func.h>
#include <tvm/ir_pass.h>
......@@ -11,9 +10,7 @@
#include <unordered_map>
namespace tvm {
namespace codegen {
using namespace ir;
namespace ir {
// use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public IRMutator {
......@@ -161,7 +158,7 @@ class HostDeviceSplitter : public IRMutator {
private:
Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os;
os << name_ << "_kernel" << device_funcs_.size();
os << name_ << "__kernel" << device_funcs_.size();
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
// isolate the device function.
IRUseDefAnalysis m;
......@@ -181,6 +178,7 @@ class HostDeviceSplitter : public IRMutator {
}
LoweredFunc f_device(n);
Array<Expr> call_args;
call_args.push_back(StringImm::make(f_device->name));
for (Var arg : n->args) {
call_args.push_back(arg);
}
......@@ -190,7 +188,8 @@ class HostDeviceSplitter : public IRMutator {
}
device_funcs_.emplace_back(f_device);
return Evaluate::make(Call::make(
Int(32), f_device->name, call_args, Call::Extern, f_device));
Int(32), intrinsic::tvm_call_device,
call_args, Call::Intrinsic));
}
// function name
......@@ -214,5 +213,5 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
return HostDeviceSplitter().Split(func);
}
} // namespace codegen
} // namespace ir
} // namespace tvm
......@@ -6,6 +6,8 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <unordered_map>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
......@@ -45,10 +47,9 @@ namespace tvm {
namespace ir {
using Halide::Internal::Region;
using runtime::StorageScope;
using runtime::ThreadScope;
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class StorageFlattener : public IRMutator {
public:
explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer) {
......@@ -59,10 +60,54 @@ class StorageFlattener : public IRMutator {
buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
}
}
Expr Mutate(Expr expr) final {
expr = IRMutator::Mutate(expr);
const Call* op = expr.as<Call>();
if (op != nullptr && op->call_type == Call::Halide) {
Stmt Flatten(Stmt stmt) {
stmt = this->Mutate(stmt);
StorageScope key; key.rank = 0;
if (move_alloc_out_) {
StorageScope key; key.rank = 0;
stmt = MergeNest(allocs_[key], stmt);
}
return stmt;
}
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->type_key == "realize_scope") {
storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
return this->Mutate(op->body);
} else if (op->type_key == "scope") {
IterVar iv(op->node.node_);
if (iv->thread_tag.length() != 0) {
ThreadScope ts = ThreadScope::make(iv->thread_tag);
curr_thread_scope_.push_back(ts);
Stmt stmt = IRMutator::Mutate_(op, s);
curr_thread_scope_.pop_back();
op = stmt.as<AttrStmt>();
bool first_scope = true;
for (const ThreadScope& t : curr_thread_scope_) {
if (t.rank == ts.rank) first_scope = false;
}
if (first_scope && move_alloc_out_) {
StorageScope key;
key.rank = ts.rank + 1;
std::vector<Stmt>& vec = allocs_[key];
if (vec.size() != 0) {
Stmt body = MergeNest(vec, op->body);
vec.clear();
return AttrStmt::make(
op->node, op->type_key, op->value, body);
}
}
return stmt;
}
}
return IRMutator::Mutate_(op, s);
}
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Provide>();
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
......@@ -70,20 +115,79 @@ class StorageFlattener : public IRMutator {
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
return e.buffer.MakeLoad(e.RelIndex(op->args));
return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
}
Stmt Mutate_(const Realize* op, const Stmt& s) final {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
return this->Mutate(op->body);
} else {
return expr;
// create a buffer entry
// TODO(tqchen) allow permutation and inference of index dimension.
BufferEntry e;
e.bounds = op->bounds;
Array<Expr> shape;
for (auto r : e.bounds) {
shape.push_back(r->extent);
}
e.buffer = Buffer(shape, op->type, key.GetName());
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
buf_map_[key].released = true;
// deduce current storage scope.
auto it = storage_scope_.find(op->func.get());
CHECK(it != storage_scope_.end());
StorageScope key; key.rank = 0;
const std::string& skey = it->second;
if (skey.length() == 0) {
if (curr_thread_scope_.size() != 0) {
key.rank = curr_thread_scope_.back().rank + 1;
}
} else {
key = StorageScope::make(skey);
}
Stmt Mutate(Stmt stmt) final {
const Realize* realize = stmt.as<Realize>();
if (realize != nullptr) {
return HandleRealize(realize);
} else if (stmt.as<Provide>()) {
return HandleProvide(stmt);
if (move_alloc_out_) {
allocs_[key].push_back(
AttrStmt::make(
e.buffer->data, "storage_scope",
StringImm::make(key.to_string()),
Evaluate::make(0)));
allocs_[key].push_back(
Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true),
Evaluate::make(0)));
return body;
} else {
return IRMutator::Mutate(stmt);
Stmt ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
ret = AttrStmt::make(
e.buffer->data, "storage_scope",
StringImm::make(key.to_string()), ret);
return ret;
}
}
}
Expr Mutate_(const Call* op, const Expr& olde) final {
Expr expr = IRMutator::Mutate_(op, olde);
op = expr.as<Call>();
if (op != nullptr && op->call_type == Call::Halide) {
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f;
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
return e.buffer.MakeLoad(e.RelIndex(op->args));
} else {
return expr;
}
}
......@@ -113,54 +217,20 @@ class StorageFlattener : public IRMutator {
}
}
};
// whether move allocation to the outmost scope as possible.
bool move_alloc_out_{true};
// The buffer assignment map
std::unordered_map<TensorKey, BufferEntry> buf_map_;
Stmt HandleRealize(const Realize* op) {
TensorKey key{op->func, op->value_index};
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
return this->Mutate(op->body);
} else {
// create a buffer entry
// TODO(tqchen) allow permutation and inference of index dimension.
BufferEntry e;
e.bounds = op->bounds;
Array<Expr> shape;
for (auto r : e.bounds) {
shape.push_back(r->extent);
}
e.buffer = Buffer(shape, op->type, key.GetName());
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
buf_map_[key].released = true;
return Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
}
}
Stmt HandleProvide(Stmt stmt) {
stmt = IRMutator::Mutate(stmt);
const Provide* op = stmt.as<Provide>();
TensorKey key{op->func, op->value_index};
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
<< "Cannot find allocated buffer for " << key.f;
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
return e.buffer.MakeStore(e.RelIndex(op->args), op->value);
}
std::unordered_map<const Node*, std::string> storage_scope_;
// The current thread scope.
std::vector<ThreadScope> curr_thread_scope_;
// The allocations by rank
std::unordered_map<StorageScope, std::vector<Stmt> > allocs_;
};
Stmt StorageFlatten(Stmt stmt,
Map<Tensor, Buffer> extern_buffer) {
stmt = StorageFlattener(extern_buffer).Mutate(stmt);
stmt = StorageFlattener(extern_buffer).Flatten(stmt);
return stmt;
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file storage_sync.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <unordered_map>
#include <unordered_set>
#include "./ir_util.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
using runtime::StorageScope;
class StorageSyncPlanner : public IRVisitor {
public:
explicit StorageSyncPlanner(StorageScope sync_scope)
: sync_scope_(sync_scope) {}
// only intended to be used once.
// The syncs inserted before each statement
std::unordered_set<const Node*> Plan(Stmt stmt) {
CHECK_EQ(scope_.size(), 0U);
scope_.push_back(std::vector<StmtEntry>());
this->Visit(stmt);
this->PlanSync(false);
return std::move(syncs_inserted_);
}
void Visit_(const Load* op) final {
CHECK(allow_load_);
const Variable* buf = op->buffer_var.as<Variable>();
StorageScope s = GetScope(buf);
if (s == sync_scope_) {
curr_stmt_.access.emplace_back(
AccessEntry(buf, kRead, s));
}
}
void Visit_(const Store* op) final {
allow_load_ = true;
CHECK_EQ(curr_stmt_.access.size(), 0U);
curr_stmt_.stmt = op;
const Variable* buf = op->buffer_var.as<Variable>();
StorageScope s = GetScope(buf);
if (s == sync_scope_) {
curr_stmt_.access.emplace_back(
AccessEntry(buf, kWrite, s));
}
// traverse child
IRVisitor::Visit_(op);
// push to the scope
scope_.back().push_back(curr_stmt_);
// clear access entry.
curr_stmt_.access.clear();
allow_load_ = false;
}
void Visit_(const AttrStmt* op) final {
if (op->type_key == "storage_scope") {
const Variable* buf = op->node.as<Variable>();
storage_scope_[buf] =
StorageScope::make(op->value.as<StringImm>()->value);
}
IRVisitor::Visit_(op);
}
void Visit_(const For* op) final {
scope_.push_back(std::vector<StmtEntry>());
IRVisitor::Visit_(op);
StmtEntry s; s.stmt = op;
s.access = PlanSync(true);
scope_.pop_back();
scope_.back().emplace_back(std::move(s));
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(Call::address_of)) {
const Load *l = op->args[0].as<Load>();
IRVisitor::Visit_(l);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const IfThenElse* op) final {
++condition_counter_;
this->Visit(op->condition);
scope_.push_back(std::vector<StmtEntry>());
this->Visit(op->then_case);
StmtEntry s; s.stmt = op;
s.access = PlanSync(false);
scope_.pop_back();
if (op->else_case.defined()) {
scope_.push_back(std::vector<StmtEntry>());
auto v = PlanSync(false);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
}
scope_.back().emplace_back(std::move(s));
--condition_counter_;
}
private:
// Storage access type
enum AccessType {
kRead,
kWrite,
kSync
};
// The access entry
struct AccessEntry {
/*! \brief The buffer variable, if any */
const Variable* buffer{nullptr};
/*! \brief The type of access */
AccessType type;
/*! \brief The storage scope */
StorageScope scope;
// constructor
AccessEntry() {}
AccessEntry(const Variable* buffer,
AccessType type,
StorageScope scope)
: buffer(buffer), type(type), scope(scope) {}
};
// The statment entry
struct StmtEntry {
// the associated statement.
const Node* stmt;
std::vector<AccessEntry> access;
};
// Get current storage scope.
StorageScope GetScope(const Variable* buf) const {
auto it = storage_scope_.find(buf);
StorageScope s; s.rank = 0;
if (it == storage_scope_.end()) return s;
return it->second;
}
// Plan the sync
std::vector<AccessEntry> PlanSync(bool is_loop) {
// unsynced reads and writes
std::vector<AccessEntry> reads;
std::vector<AccessEntry> writes;
const std::vector<StmtEntry>& seq = scope_.back();
// if it is a loop, rotate two times to consider effect of loop.
size_t max_seq = seq.size();
if (is_loop) max_seq *= 2;
// simulation based approach to find dependenceies
for (size_t i = 0; i < max_seq; ++i) {
const StmtEntry& s = seq[i % seq.size()];
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
if (sync_before_stmt) {
reads.clear();
writes.clear();
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc)) {
sync_before_stmt = true; break;
}
} else if (acc.type == kSync) {
reads.clear(); writes.clear();
}
}
// If sync is inserted. remove the irrelevant things.
if (sync_before_stmt) {
reads.clear(); writes.clear();
}
// Add the read/write of current statement
for (const AccessEntry& acc : s.access) {
if (acc.type == kRead) {
reads.push_back(acc);
} else if (acc.type == kWrite) {
writes.push_back(acc);
} else if (acc.type == kSync) {
reads.clear(); writes.clear();
}
}
if (sync_before_stmt) {
CHECK_EQ(condition_counter_, 0)
<< "Cannot insert syncs inside condition";
syncs_inserted_.insert(s.stmt);
}
}
// return the exposed entries, remove unecessary ones.
int sync_count = 0;
// head are before first sync, tail are after last sync
std::vector<AccessEntry> head, tail;
for (const StmtEntry& s : seq) {
if (syncs_inserted_.count(s.stmt)) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(AccessEntry(nullptr, kSync, sync_scope_));
}
++sync_count;
}
for (const AccessEntry& acc : s.access) {
if (acc.type == kSync) {
if (sync_count != 0) {
tail.clear();
} else {
head.push_back(AccessEntry(nullptr, kSync, sync_scope_));
}
++sync_count;
} else {
if (sync_count != 0) {
tail.push_back(acc);
} else {
head.push_back(acc);
}
}
}
}
head.insert(head.end(), tail.begin(), tail.end());
return head;
}
// find conflicting entry in vec.
static bool FindConflict(const std::vector<AccessEntry>& vec,
const AccessEntry& e) {
for (const AccessEntry& x : vec) {
if (x.buffer == e.buffer) return true;
}
return false;
}
// Whether we are inside condition.
int condition_counter_{0};
// whether load is enabled.
bool allow_load_{false};
// the current free stmt entry.
StmtEntry curr_stmt_;
// access scope
std::vector<std::vector<StmtEntry> > scope_;
// The storage scope of each buffer
std::unordered_map<const Variable*, StorageScope> storage_scope_;
// The syncs inserted before each statement
std::unordered_set<const Node*> syncs_inserted_;
// The sync scope we care about.
StorageScope sync_scope_;
};
class StorageSyncInserter : public IRMutator {
public:
StorageSyncInserter(StorageScope sync_scope,
std::unordered_set<const Node*> syncs)
: sync_scope_(sync_scope), syncs_(syncs) {}
Stmt Mutate(Stmt stmt) final {
stmt = IRMutator::Mutate(stmt);
if (syncs_.count(stmt.get())) {
stmt = Block::make(
Evaluate::make(
Call::make(Int(32), intrinsic::tvm_storage_sync,
{StringImm::make(sync_scope_.to_string())},
Call::Intrinsic)),
stmt);
}
return stmt;
}
StorageScope sync_scope_;
std::unordered_set<const Node*> syncs_;
};
Stmt StorageSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::make(storage_scope);
auto syncs = StorageSyncPlanner(sync_scope).Plan(stmt);
return StorageSyncInserter(sync_scope, syncs).Mutate(stmt);
}
LoweredFunc StorageSync(LoweredFunc f, std::string storage_scope) {
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = StorageSync(f->body, storage_scope);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
......@@ -13,7 +13,7 @@
#include <mutex>
#include "./cuda_common.h"
#include "../void_addr_args.h"
#include "../thread_axis_args.h"
#include "../thread_storage_scope.h"
namespace tvm {
namespace runtime {
......
......@@ -11,7 +11,7 @@
#include <string>
#include <unordered_map>
#include "../void_addr_args.h"
#include "../thread_axis_args.h"
#include "../thread_storage_scope.h"
namespace tvm {
namespace runtime {
......@@ -87,13 +87,13 @@ class OpenCLWrappedFunc {
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
cl_uint work_dim = static_cast<cl_uint>(thread_axis_cfg_.work_dim());
for (cl_uint i = 0; i < work_dim; ++i) {
wl.work_size[i + 3] *= wl.work_size[i];
wl.work_size[i] *= wl.work_size[i + 3];
}
// launch kernel
OPENCL_CALL(clEnqueueNDRangeKernel(
queue, kernel, work_dim, nullptr,
wl.work_size + 3,
wl.work_size,
wl.work_size + 3,
0, nullptr, nullptr));
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file thread_axis_args.h
* \file thread_storage_scope.h
* \brief Extract thread axis configuration from TVMArgs.
*/
#ifndef TVM_RUNTIME_THREAD_AXIS_ARGS_H_
#define TVM_RUNTIME_THREAD_AXIS_ARGS_H_
#ifndef TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#define TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
#include <tvm/runtime/packed_func.h>
#include <string>
#include <vector>
namespace tvm {
namespace runtime {
/*! \brief class to represent storage scope */
struct StorageScope {
/*! \brief The rank of the storage */
int rank{0};
// comparator
inline bool operator==(const StorageScope& other) const {
return rank == other.rank;
}
inline std::string to_string() const {
switch (rank) {
case 0: return "global";
case 1: return "shared";
case 2: return "local";
default: LOG(FATAL) << "unknown storage scope"; return "";
}
}
/*!
* \brief make storage scope from string
* \param s The string to be parsed.
* \return The storage scope.
*/
static StorageScope make(const std::string& s) {
StorageScope r;
if (s == "global") {
r.rank = 0;
} else if (s == "shared") {
r.rank = 1;
} else if (s == "local") {
r.rank = 2;
} else {
LOG(FATAL) << "unknown storage scope " << s;
}
return r;
}
};
/*! \brief class to represent thread scope */
struct ThreadScope {
/*! \brief The rank of thread scope */
int rank{0};
/*! \brief the dimension index under the rank */
int dim_index{0};
/*!
* \brief make storage scope from string
* \param s The string to be parsed.
* \return The storage scope.
*/
static ThreadScope make(const std::string& s) {
ThreadScope r;
if (s.compare(0, 9, "blockIdx.") == 0) {
r.rank = 0;
r.dim_index = static_cast<int>(s[9] - 'x');
} else if (s.compare(0, 10, "threadIdx.") == 0) {
r.rank = 1;
r.dim_index = static_cast<int>(s[10] - 'x');
} else {
LOG(FATAL) << "Unknown threadscope " << s;
}
return r;
}
};
/*! \brief workload speccification */
struct ThreadWorkLoad {
// array, first three are thread configuration.
......@@ -21,14 +85,14 @@ struct ThreadWorkLoad {
* \return i-th block dim
*/
inline size_t block_dim(size_t i) const {
return work_size[i];
return work_size[i + 3];
}
/*!
* \param i The grid dimension.
* \return i-th grid dim
*/
inline size_t grid_dim(size_t i) const {
return work_size[i + 3];
return work_size[i];
}
};
/*! \brief Thread axis configuration */
......@@ -40,27 +104,9 @@ class ThreadAxisConfig {
std::vector<bool> filled(6, false);
for (size_t i = 0; i < thread_axis_tags.size(); ++i) {
const std::string& tag = thread_axis_tags[i];
if (tag == "threadIdx.x") {
arg_index_map_.push_back(0);
filled[0] = true;
} else if (tag == "threadIdx.y") {
arg_index_map_.push_back(1);
filled[1] = true;
} else if (tag == "threadIdx.z") {
arg_index_map_.push_back(2);
filled[2] = true;
} else if (tag == "blockIdx.x") {
arg_index_map_.push_back(3 + 0);
filled[3] = true;
} else if (tag == "blockIdx.y") {
arg_index_map_.push_back(3 + 1);
filled[3 + 1] = true;
} else if (tag == "blockIdx.z") {
arg_index_map_.push_back(3 + 2);
filled[3 + 2] = true;
} else {
LOG(FATAL) << "do not known thread_tag=" << tag;
}
ThreadScope ts = ThreadScope::make(tag);
arg_index_map_.push_back(ts.rank * 3 + ts.dim_index);
filled[ts.rank * 3 + ts.dim_index] = true;
}
work_dim_ = 3;
for (int i = 0; i < 3; ++i) {
......@@ -103,4 +149,13 @@ class ThreadAxisConfig {
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_THREAD_AXIS_ARGS_H_
namespace std {
template <>
struct hash<::tvm::runtime::StorageScope> {
std::size_t operator()(const ::tvm::runtime::StorageScope& k) const {
return static_cast<size_t>(k.rank);
}
};
} // namespace std
#endif // TVM_RUNTIME_THREAD_STORAGE_SCOPE_H_
......@@ -9,6 +9,7 @@
#include <tvm/schedule_pass.h>
#include "./int_set.h"
#include "./graph.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace schedule {
......@@ -181,24 +182,13 @@ BoundProp(const Array<Operation>& post_order,
// check if scope
bool ScopeRelax(const IterVar& iv, const std::string& scope) {
inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
using runtime::ThreadScope;
using runtime::StorageScope;
if (iv->thread_tag.length() == 0) return false;
if (scope.length() == 0) return false;
static std::unordered_map<std::string, int> scope_rank{
{"global", 0},
{"shared", 1},
{"local", 2}
};
static std::unordered_map<std::string, int> thread_tag_rank{
{"blockIdx.x", 0},
{"blockIdx.y", 0},
{"blockIdx.z", 0},
{"threadIdx.x", 1},
{"threadIdx.y", 1},
{"threadIdx.z", 1}
};
return scope_rank.at(scope) <= thread_tag_rank.at(iv->thread_tag);
return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
}
void InferBound(const Stage& stage,
......@@ -248,7 +238,7 @@ void InferBound(const Stage& stage,
}
auto result = BoundProp(post_order, &bp_state);
// Set relaxation
// Set relaxation for the threads in parent.
Map<IterVar, IntSet> relax_set;
Stage s = stage;
while (s->attach_type == kScope) {
......@@ -259,6 +249,7 @@ void InferBound(const Stage& stage,
}
}
}
for (auto iv : stage->op->root_iter_vars()) {
CHECK(result.count(iv));
CHECK(!rmap->count(iv));
......
......@@ -32,7 +32,7 @@ template<typename T>
inline bool GetConst(Expr e, T* out);
template<>
bool GetConst<int64_t>(Expr e, int64_t *out) {
inline bool GetConst<int64_t>(Expr e, int64_t *out) {
if (e.type().is_vector()) return false;
const int64_t *v = as_const_int(e);
if (v) {
......@@ -42,7 +42,7 @@ bool GetConst<int64_t>(Expr e, int64_t *out) {
}
}
template<>
bool GetConst<uint64_t>(Expr e, uint64_t *out) {
inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
if (e.type().is_vector()) return false;
const uint64_t *v = as_const_uint(e);
if (v) {
......@@ -77,7 +77,7 @@ template<>
inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) {
if (is_zero(b)) return a;
TVM_CONST_PROPAGATION(sub, -);
return ir::Add::make(a, b);
return ir::Sub::make(a, b);
}
template<>
......@@ -91,7 +91,7 @@ inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) {
template<>
inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) {
if (is_one(b)) return a;
return ir::Mul::make(a, b);
return ir::Div::make(a, b);
}
template<>
......
......@@ -11,6 +11,7 @@
#include "../pass/ir_util.h"
#include "./int_set.h"
#include "./graph.h"
#include "./compute_expr.h"
namespace tvm {
namespace schedule {
......@@ -48,6 +49,49 @@ void PassDownFlag(const Stage& s,
}
/*!
* \brief message passing to find if boundary checking on IterVar is needed.
* \param s The stage to be used.
* \param p_state The message passing state
* IterVar->flag
*/
void PassUpBoundCheck(const Stage& s,
const Map<IterVar, Range>& dom_map,
std::unordered_map<IterVar, bool>* p_state) {
auto& state = *p_state;
using Halide::Internal::can_prove;
for (size_t i = s->relations.size(); i != 0; --i) {
IterVarRelation rel = s->relations[i - 1];
if (rel.as<SplitNode>()) {
const SplitNode* s = rel.as<SplitNode>();
bool outer = state.at(s->outer);
bool inner = state.at(s->inner);
Expr factor = dom_map.at(s->inner)->extent;
Expr step = dom_map.at(s->outer)->extent;
if (outer || inner) {
state[s->parent] = true;
} else {
if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
state[s->parent] = false;
} else {
state[s->parent] = true;
}
}
} else if (rel.as<FuseNode>()) {
const FuseNode* s = rel.as<FuseNode>();
bool fused = state.at(s->fused);
state[s->outer] = fused;
state[s->inner] = fused;
} else if (rel.as<RebaseNode>()) {
const RebaseNode* s = rel.as<RebaseNode>();
state[s->parent] = state.at(s->rebased);
} else {
LOG(FATAL) << "unknown relation type";
}
}
}
/*!
* \brief use message passing to calculate the assignment of each Var inside the loop body.
* \param s The schedule to be used.
* \param dom_map The domain map of each iteration variable's domain
......@@ -107,8 +151,9 @@ MakeLoopNest(const Stage& sch,
const Map<IterVar, Range>& dom_map,
size_t begin_loop,
bool reduce_init_loop,
std::unordered_map<IterVar, Expr>* p_value_map,
const std::unordered_map<IterVar, bool>& skip_iter) {
const std::unordered_map<IterVar, bool>& bound_state,
const std::unordered_map<IterVar, bool>& skip_iter,
std::unordered_map<IterVar, Expr>* p_value_map) {
auto leaf_iter_vars = sch->leaf_iter_vars;
Stmt no_op = Evaluate::make(0);
// create the loop nest
......@@ -167,6 +212,21 @@ MakeLoopNest(const Stage& sch,
}
// message passing to get offset of root iter vars.
PassUpOffset(sch, dom_map, &value_map);
// insert conditions
for (IterVar iv : sch->op->root_iter_vars()) {
if (skip_iter.count(iv)) continue;
Range dom = dom_map.at(iv);
if (bound_state.at(iv)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), dom->min) < dom->extent;
nest.back().emplace_back(IfThenElse::make(condition, no_op));
}
CHECK(iv->dom.defined());
if (!reduce_init_loop && !iv->dom.same_as(dom)) {
Expr condition = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min) < iv->dom->extent;
nest.back().emplace_back(IfThenElse::make(condition, no_op));
}
}
return nest;
}
......@@ -175,7 +235,16 @@ Stmt MakeLoop(const Stage& s,
Stmt provide,
Stmt init) {
std::unordered_map<IterVar, Expr> value_map;
auto nest = MakeLoopNest(s, dom_map, 0, false, &value_map, {});
// bound check state.
std::unordered_map<IterVar, bool> bound_state;
for (IterVar iv : s->leaf_iter_vars) {
bound_state[iv] = false;
}
PassUpBoundCheck(s, dom_map, &bound_state);
auto nest = MakeLoopNest(s, dom_map, 0, false,
bound_state, {}, &value_map);
provide = Substitute(provide, value_map);
if (init.defined()) {
// try to find the location to insert the initialization.
......@@ -204,13 +273,13 @@ Stmt MakeLoop(const Stage& s,
}
// skip loops that does not relates to axis.
std::unordered_map<IterVar, bool> skip_iter;
for (size_t i = begin_loop; i < leaf_iter_vars.size(); ++i) {
auto iv = leaf_iter_vars[i];
int flag = reduce_state.at(iv);
if ((flag & 1) == 0) skip_iter[iv] = true;
for (auto kv : reduce_state) {
int flag = kv.second;
if ((flag & 1) == 0) skip_iter[kv.first] = true;
}
auto init_nest = MakeLoopNest(
s, dom_map, begin_loop, true, &init_value_map, skip_iter);
s, dom_map, begin_loop, true,
bound_state, skip_iter, &init_value_map);
init = Substitute(init, init_value_map);
init = MergeNest(init_nest, init);
// common nest
......@@ -250,7 +319,6 @@ Stmt MakeRealize(const ComputeOpNode* op,
void MakeReduction(const ComputeOpNode* op,
const std::vector<Tensor>& tensors,
const Map<IterVar, Range>& dom_map,
Stmt* init,
Stmt* provide) {
Stmt no_op = Evaluate::make(0);
......@@ -279,43 +347,49 @@ void MakeReduction(const ComputeOpNode* op,
*provide = Provide::make(t->op, t->value_index, update_value, args);
}
Stmt MakePipeline(const Stage& sch,
Stmt MakePipeline(const Stage& s,
const Map<IterVar, Range>& dom_map,
Stmt consumer) {
std::vector<Tensor> tensors;
for (int i = 0; i < sch->op->num_outputs(); ++i) {
tensors.emplace_back(sch->op.output(i));
for (int i = 0; i < s->op->num_outputs(); ++i) {
tensors.emplace_back(s->op.output(i));
}
Stmt init, provide;
const ComputeOpNode* compute = sch->op.as<ComputeOpNode>();
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
if (compute) {
if (compute->reduce_axis.size() == 0) {
provide = MakeProvide(compute, tensors);
} else {
MakeReduction(compute, tensors, dom_map, &init, &provide);
MakeReduction(compute, tensors, &init, &provide);
}
} else {
LOG(FATAL) << "not supported op " << sch->op->type_key();
LOG(FATAL) << "not supported op " << s->op->type_key();
}
Stmt producer = MakeLoop(sch, dom_map, provide, init);
producer = ProducerConsumer::make(sch->op, true, producer);
Stmt producer = MakeLoop(s, dom_map, provide, init);
producer = ProducerConsumer::make(s->op, true, producer);
Stmt pipeline = producer;
if (consumer.defined()) {
consumer = ProducerConsumer::make(sch->op, false, consumer);
consumer = ProducerConsumer::make(s->op, false, consumer);
pipeline = Block::make(producer, consumer);
}
if (sch->op.as<ComputeOpNode>()) {
return MakeRealize(sch->op.as<ComputeOpNode>(),
if (s->op.as<ComputeOpNode>()) {
pipeline = MakeRealize(s->op.as<ComputeOpNode>(),
dom_map, tensors, pipeline);
} else {
LOG(FATAL) << "not supported op";
return Stmt();
}
// use attribute to mark scope of the operation.
pipeline = AttrStmt::make(
s->op, "realize_scope",
StringImm::make(s->scope),
pipeline);
return pipeline;
}
// inject the operator's realization on the stmt.
......
import tvm
import numpy as np
def test_gemm():
# graph
nn = 1235
n = tvm.Var('n')
#n = tvm.convert(nn)
m = n
l = n
A = tvm.placeholder((n, l), name='A')
B = tvm.placeholder((m, l), name='B')
AA = tvm.compute(A.shape, lambda *i : A(*i), name="AA")
BB = tvm.compute(B.shape, lambda *i : B(*i), name="BB")
k = tvm.IterVar((0, l), name='k')
CC = tvm.compute(
(n, m),
lambda ii, jj: tvm.sum(AA[ii, k] * BB[jj, k], axis=k),
name='CC')
C = tvm.compute(CC.shape, lambda *i: CC(*i), name="C")
# schedule
s = tvm.Schedule(C.op)
xtile, ytile = 32, 32
s[AA].set_scope("shared")
#s[CC].set_scope("global")
s[BB].set_scope("shared")
scale = 8
num_thread = 8
block_factor = scale * num_thread
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
block_y = tvm.IterVar(thread_tag="blockIdx.y")
thread_y = tvm.IterVar((0, num_thread), thread_tag="threadIdx.y")
_, yi = s[C].split(C.op.axis[0], factor=block_factor, outer=block_y)
_, xi = s[C].split(C.op.axis[1], factor=block_factor, outer=block_x)
s[C].reorder(block_y, block_x, yi, xi)
_, yi = s[C].split(yi, outer=thread_y)
_, xi = s[C].split(xi, outer=thread_x)
s[C].reorder(thread_y, thread_x, yi, xi)
yo, xo = CC.op.axis
s[CC].reorder(k, yo, xo)
s[CC].compute_at(s[C], thread_x)
s[AA].compute_at(s[CC], k)
s[BB].compute_at(s[CC], k)
_, xi = s[AA].split(s[AA].op.axis[0], outer=thread_y)
_, xi = s[AA].split(xi, outer=thread_x)
_, xi = s[BB].split(s[BB].op.axis[0], outer=thread_y)
_, xi = s[BB].split(xi, outer=thread_x)
# lowering test
s.normalize()
def check_device(target):
codes = []
f = tvm.build(s, [A, B, C], target, record_codes=codes)
for c in codes[1:]:
print(c)
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
return
# launch the kernel.
n = nn
m = n
l = n
a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
b_np = np.random.uniform(size=(m, l)).astype(B.dtype)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
tvm.init_opencl()
check_device("cuda")
check_device("opencl")
if __name__ == "__main__":
test_gemm()
......@@ -24,8 +24,8 @@ def test_add_pipeline():
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3)
fsplits = tvm.codegen.SplitHostDevice(fapi)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3)
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
def check_cuda():
output_ssa = False
......
......@@ -18,7 +18,7 @@ def test_makeapi():
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
num_packed_args = 2
f = tvm.codegen.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
f = tvm.ir_pass.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
assert(len(f.args) == 5)
output_ssa = False
......
import tvm
def test_storage_sync():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.Schedule(A2.op)
block_x = tvm.IterVar(thread_tag="blockIdx.x")
xo, xi = s[A2].split(A2.op.axis[0], factor=8, outer=block_x)
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.collections.Map)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 2)
flist = tvm.ir_pass.SplitHostDevice(f)
f = flist[1]
f = tvm.ir_pass.StorageSync(f, "shared")
print(f.body)
if __name__ == "__main__":
test_storage_sync()
......@@ -16,9 +16,7 @@ def test_stack_vm_basic():
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "print_shape", [Ab], 1)
print(fapi.body)
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
f(a)
......@@ -41,8 +39,7 @@ def test_stack_vm_loop():
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1),
tvm.make.Evaluate(tvm_call_global("tvm_stack_vm_print", i))))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "ramp", [Ab], 1)
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......@@ -63,8 +60,7 @@ def test_stack_vm_cond():
tvm.make.Load(dtype, Ab.data, i) + 1, i + 1),
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 2, i + 1)))
print(stmt)
fapi = tvm.codegen.MakeAPI(stmt, "test", [Ab], 1)
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment