Commit f6c043eb by Tianqi Chen Committed by GitHub

[LLVM/RUNTIME] Support Parallel for on CPU (#54)

parent 2f462cca
...@@ -173,11 +173,12 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -173,11 +173,12 @@ LoweredFunc MakeAPI(Stmt body,
int num_unpacked_args); int num_unpacked_args);
/*! /*!
* \brief Count number of undefined vars in f. * \brief Find undefined vars in the statement.
* \param f The function to be checked. * \param stmt The statement to be checked. * \param defs The vars that are defined.
* \return Number of undefined vars. * \param defs The vars that is defined.
* \return Array of undefined vars.
*/ */
Array<Var> UndefinedVars(const LoweredFunc& f); Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
/*! /*!
* \brief Split the function into a host function and device functions. * \brief Split the function into a host function and device functions.
......
...@@ -226,6 +226,18 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod, ...@@ -226,6 +226,18 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
TVMContext ctx); TVMContext ctx);
/*! /*!
* \brief Free the Module
* \param mod The module to be freed.
*
* \note This may not free up the module's resources
* if there is an active TVMFunctionHandle that uses the module,
* or if this module is imported by another active module.
*
* All functions remain valid until TVMFuncFree is called.
*/
TVM_DLL int TVMModFree(TVMModuleHandle mod);
/*!
* \brief Backend function for modules to get function * \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function). * from its environment mod_node (its imports and global function).
* *
...@@ -242,17 +254,25 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod, ...@@ -242,17 +254,25 @@ TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node, TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name, const char* func_name,
TVMFunctionHandle *out); TVMFunctionHandle *out);
/*! /*!
* \brief Free the Module * \brief Backend function for running parallel for loop.
* \param mod The module to be freed.
* *
* \note This may not free up the module's resources. * \note This API is supposed to be used by backend,
* If there is active TVMFunctionHandle uses the module * it is not supposed to be used by user.
* Or if this module is imported by another active module.
* *
* The all functions remains valid until TVMFuncFree is called. * \param begin The start of iteration.
* \param end The end of iteration.
* \param lambda The lambda function to be executed.
* \param env The environment of lambda function.
*
* \return 0 when no error is thrown, -1 when failure happens
*/ */
TVM_DLL int TVMModFree(TVMModuleHandle mod); TVM_DLL int TVMBackendParallelFor(
int64_t begin,
int64_t end,
int (*lambda)(int64_t begin, int64_t end, void* env),
void* env);
/*! /*!
* \brief Free the function when it is no longer needed. * \brief Free the function when it is no longer needed.
......
...@@ -34,7 +34,8 @@ enum AttachType : int { ...@@ -34,7 +34,8 @@ enum AttachType : int {
/*! \brief IterVar type */ /*! \brief IterVar type */
enum IterVarType : int { enum IterVarType : int {
kUnrolled = 1, kUnrolled = 1,
kVectorized = 2 kVectorized = 2,
kParallel = 3
}; };
/*! \brief Stage, contains scheduling for a stage of computation. */ /*! \brief Stage, contains scheduling for a stage of computation. */
...@@ -153,6 +154,12 @@ class Stage : public NodeRef { ...@@ -153,6 +154,12 @@ class Stage : public NodeRef {
*/ */
Stage& unroll(IterVar var); // NOLINT(*) Stage& unroll(IterVar var); // NOLINT(*)
/*! /*!
* \brief Parallelize iteration.
* \param var The axis to be parallelized.
* \return reference to self.
*/
Stage& parallel(IterVar var); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled. * \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled. * \return whether the stage has been scheduled.
*/ */
......
...@@ -257,3 +257,13 @@ class Stage(NodeBase): ...@@ -257,3 +257,13 @@ class Stage(NodeBase):
The iteration to be unrolled. The iteration to be unrolled.
""" """
_api_internal._StageUnroll(self, var) _api_internal._StageUnroll(self, var)
def parallel(self, var):
"""Parallelize the iteration.
Parameters
----------
var : IterVar
The iteration to be parallelized.
"""
_api_internal._StageParallel(self, var)
...@@ -280,6 +280,12 @@ TVM_REGISTER_API(_StageVectorize) ...@@ -280,6 +280,12 @@ TVM_REGISTER_API(_StageVectorize)
.vectorize(args[1]); .vectorize(args[1]);
}); });
TVM_REGISTER_API(_StageParallel)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.parallel(args[1]);
});
TVM_REGISTER_API(_ScheduleNormalize) TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule() args[0].operator Schedule()
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_llvm.h" #include "./codegen_llvm.h"
#include "../../arithmetic/compute_expr.h" #include "../../arithmetic/compute_expr.h"
...@@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name, ...@@ -30,6 +31,7 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int8_ = llvm::Type::getInt8Ty(*ctx); t_int8_ = llvm::Type::getInt8Ty(*ctx);
t_int16_ = llvm::Type::getInt16Ty(*ctx); t_int16_ = llvm::Type::getInt16Ty(*ctx);
t_int32_ = llvm::Type::getInt32Ty(*ctx); t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx); t_float64_ = llvm::Type::getDoubleTy(*ctx);
t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8); t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_}); t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
...@@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name, ...@@ -43,6 +45,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_type_, t_tvm_type_,
t_tvm_context_}); t_tvm_context_});
t_tvm_value_ = llvm::StructType::create({t_float64_}); t_tvm_value_ = llvm::StructType::create({t_float64_});
t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
t_int_, {t_int64_, t_int64_, t_void_p_}, false);
md_builder_.reset(new llvm::MDBuilder(*ctx)); md_builder_.reset(new llvm::MDBuilder(*ctx));
md_very_likely_branch_ = md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0); md_builder_->createBranchWeights(1 << 30, 0);
...@@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name, ...@@ -70,7 +74,11 @@ void CodeGenLLVM::Init(const std::string& module_name,
f_tvm_api_set_last_error_ = llvm::Function::Create( f_tvm_api_set_last_error_ = llvm::Function::Create(
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false), llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get()); llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
f_tvm_parallel_for_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_int64_, t_int64_, t_f_tvm_par_for_lambda_->getPointerTo(), t_void_p_}
, false),
llvm::Function::ExternalLinkage, "TVMBackendParallelFor", module_.get());
this->InitTarget(target_triple); this->InitTarget(target_triple);
// initialize builder // initialize builder
builder_.reset(new IRBuilder(*ctx)); builder_.reset(new IRBuilder(*ctx));
...@@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) { ...@@ -141,7 +149,9 @@ void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
} }
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_); llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block); builder_->SetInsertPoint(block);
builder_->CreateRet(builder_->CreateCall(f, args)); llvm::CallInst* call = builder_->CreateCall(f, args);
call->setTailCall(true);
builder_->CreateRet(call);
} }
class FPassManager : public llvm::legacy::FunctionPassManager { class FPassManager : public llvm::legacy::FunctionPassManager {
...@@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) { ...@@ -545,7 +555,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return nullptr; return nullptr;
} }
llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) { llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
// create emit codes that checks and load the function. // create emit codes that checks and load the function.
using llvm::BasicBlock; using llvm::BasicBlock;
BasicBlock* fail_block = BasicBlock::Create( BasicBlock* fail_block = BasicBlock::Create(
...@@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) { ...@@ -563,34 +573,15 @@ llvm::BasicBlock* CodeGenLLVM::CheckPackedCallSuccess(llvm::Value* retcode) {
return end_block; return end_block;
} }
void CodeGenLLVM::Visit_(const For* op) { void CodeGenLLVM::Visit_(const For* op) {
using llvm::BasicBlock;
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
Type t = op->min.type(); if (op->for_type == ForType::Serial) {
llvm::Value* init = ConstInt32(0); CreateSerialFor(ConstInt32(0), MakeValue(op->extent),
llvm::Value* extent = MakeValue(op->extent); op->loop_var, op->body);
builder_->CreateBr(for_head); } else if (op->for_type == ForType::Parallel) {
CreateParallelFor(op);
builder_->SetInsertPoint(for_head); } else {
llvm::PHINode* index = builder_->CreatePHI(LLVMType(t), 2); LOG(FATAL) << "cannot handle for type " << op->for_type;
index->addIncoming(init, pre_block); }
llvm::Value* cond = CreateLT(t, index, extent);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[op->loop_var.get()] = index;
this->Visit(op->body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
} }
void CodeGenLLVM::Visit_(const IfThenElse* op) { void CodeGenLLVM::Visit_(const IfThenElse* op) {
...@@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) { ...@@ -807,7 +798,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_); llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::Value* retcode = builder_->CreateCall( llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out}); f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
init_block = CheckPackedCallSuccess(retcode); init_block = CheckCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align); llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block); builder_->CreateBr(end_block);
// end block // end block
...@@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) { ...@@ -846,7 +837,7 @@ llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
} }
llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_); llvm::Value* ret_value = builder_->CreateAlloca(t_tvm_value_);
llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_); llvm::Value* ret_tcode = builder_->CreateAlloca(t_int_);
CheckPackedCallSuccess( CheckCallSuccess(
builder_->CreateCall( builder_->CreateCall(
f_tvm_func_call_, f_tvm_func_call_,
{handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode})); {handle, targs, tcodes, ConstInt32(nargs), ret_value, ret_tcode}));
...@@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) { ...@@ -934,6 +925,94 @@ llvm::Value* CodeGenLLVM::GetConstString(const std::string& str) {
} }
} }
void CodeGenLLVM::CreateParallelFor(const For* op) {
using llvm::BasicBlock;
llvm::Value* min = MakeValue(op->min);
llvm::Value* extent = MakeValue(op->extent);
min = builder_->CreateIntCast(min, t_int64_, op->min.type().is_int());
extent = builder_->CreateIntCast(extent, t_int64_, op->min.type().is_int());
// fields to be packed into closure.
Var loop_var(op->loop_var.node_);
Array<Var> vfields = ir::UndefinedVars(op->body, {loop_var});
std::vector<llvm::Type*> fields;
for (Var v : vfields) {
auto it = var_map_.find(v.get());
CHECK(it != var_map_.end());
fields.push_back(it->second->getType());
}
// closure data
llvm::StructType* tcdata = llvm::StructType::create(fields);
llvm::Function* f = llvm::Function::Create(
t_f_tvm_par_for_lambda_,
llvm::Function::PrivateLinkage,
"__tvm_par_for_lambda", module_.get());
// allocate and setup the closure, call the closure.
llvm::Value* cdata = builder_->CreateAlloca(tcdata, ConstInt32(1));
llvm::Value* zero = ConstInt32(0);
for (size_t i = 0; i < vfields.size(); ++i) {
builder_->CreateStore(
var_map_.at(vfields[i].get()),
builder_->CreateInBoundsGEP(cdata, {zero, ConstInt32(i)}));
}
BasicBlock* par_for_end = CheckCallSuccess(
builder_->CreateCall(
f_tvm_parallel_for_,
{min, extent, f, builder_->CreatePointerCast(cdata, t_void_p_)}));
// Setup the closure function.
BasicBlock *lambda_entry = BasicBlock::Create(*ctx_, "entry", f);
builder_->SetInsertPoint(lambda_entry);
auto it = f->arg_begin();
llvm::Value* begin = &(*it++);
llvm::Value* end = &(*it++);
cdata = &(*it++);
begin = CreateCast(Int(64), op->loop_var.type(), begin);
end = CreateCast(Int(64), op->loop_var.type(), end);
cdata = builder_->CreatePointerCast(cdata, tcdata->getPointerTo());
// setup new variable map, swap it with current var context.
std::unordered_map<const Variable*, llvm::Value*> new_vmap;
for (size_t i = 0; i < vfields.size(); ++i) {
new_vmap[vfields[i].get()] =
builder_->CreateLoad(builder_->CreateInBoundsGEP(
cdata, {zero, ConstInt32(i)}));
}
std::swap(function_, f);
std::swap(new_vmap, var_map_);
CreateSerialFor(begin, end, op->loop_var, op->body);
builder_->CreateRet(ConstInt32(0));
// swap the var map back, now we are back on track.
std::swap(new_vmap, var_map_);
std::swap(function_, f);
builder_->SetInsertPoint(par_for_end);
}
void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body) {
using llvm::BasicBlock;
Type t = loop_var.type();
BasicBlock* for_head = BasicBlock::Create(
*ctx_, "for_head", function_);
BasicBlock* for_body = BasicBlock::Create(
*ctx_, "for_body", function_);
BasicBlock* for_end = BasicBlock::Create(
*ctx_, "for_end", function_);
BasicBlock* pre_block = builder_->GetInsertBlock();
builder_->CreateBr(for_head);
builder_->SetInsertPoint(for_head);
llvm::PHINode* index = builder_->CreatePHI(begin->getType(), 2);
index->addIncoming(begin, pre_block);
llvm::Value* cond = CreateLT(t, index, end);
builder_->CreateCondBr(cond, for_body, for_end, md_very_likely_branch_);
// body of for
builder_->SetInsertPoint(for_body);
var_map_[loop_var.get()] = index;
this->Visit(body);
llvm::Value* next_index = CreateAdd(t, index, ConstInt32(1));
index->addIncoming(next_index, builder_->GetInsertBlock());
builder_->CreateBr(for_head);
// end of for
builder_->SetInsertPoint(for_end);
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_LLVM_VERSION #endif // TVM_LLVM_VERSION
...@@ -152,10 +152,12 @@ class CodeGenLLVM : public IRVisitor { ...@@ -152,10 +152,12 @@ class CodeGenLLVM : public IRVisitor {
llvm::StructType* t_tvm_type_{nullptr}; llvm::StructType* t_tvm_type_{nullptr};
llvm::StructType* t_tvm_array_{nullptr}; llvm::StructType* t_tvm_array_{nullptr};
llvm::StructType* t_tvm_value_{nullptr}; llvm::StructType* t_tvm_value_{nullptr};
llvm::FunctionType* t_f_tvm_par_for_lambda_{nullptr};
// tvm api functions // tvm api functions
llvm::Function* f_tvm_func_call_{nullptr}; llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr}; llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr}; llvm::Function* f_tvm_api_set_last_error_{nullptr};
llvm::Function* f_tvm_parallel_for_{nullptr};
// The acting body // The acting body
llvm::BasicBlock* block_{nullptr}; llvm::BasicBlock* block_{nullptr};
// Last value returned codegen call. // Last value returned codegen call.
...@@ -176,10 +178,15 @@ class CodeGenLLVM : public IRVisitor { ...@@ -176,10 +178,15 @@ class CodeGenLLVM : public IRVisitor {
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index); llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value); llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str); llvm::Value* GetPackedFuncHandle(const std::string& str);
// Create parallel for.
void CreateParallelFor(const For* op);
// Create serial for
void CreateSerialFor(llvm::Value* begin, llvm::Value* end,
const VarExpr& loop_var, const Stmt& body);
// Check if the call to packed function is successful // Check if the call to packed function is successful
// if not directly finalize function and pass on return code. // if not directly finalize function and pass on return code.
// return the end block after the check // return the end block after the check
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode); llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
// Initialize target // Initialize target
void InitTarget(const std::string& target); void InitTarget(const std::string& target);
// Add a function to set global module context // Add a function to set global module context
......
/*!
* Copyright (c) 2017 by Contributors
* \file lowered_func.cc
*/
#include <tvm/lowered_func.h>
namespace tvm {
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const LoweredFuncNode *op, IRPrinter *p) {
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
} // namespace tvm
...@@ -188,7 +188,7 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -188,7 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
n->is_packed_func = num_unpacked_args == 0; n->is_packed_func = num_unpacked_args == 0;
n->body = MergeNest({seq_init, seq_check}, body); n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n); LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f); Array<Var> undefined = UndefinedVars(f->body, f->args);
if (undefined.size() != 0) { if (undefined.size() != 0) {
std::ostringstream os; std::ostringstream os;
for (Var v : undefined) { for (Var v : undefined) {
......
...@@ -220,12 +220,12 @@ class HostDeviceSplitter : public IRMutator { ...@@ -220,12 +220,12 @@ class HostDeviceSplitter : public IRMutator {
}; };
Array<Var> UndefinedVars(const LoweredFunc& f) { Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
IRUseDefAnalysis m; IRUseDefAnalysis m;
for (Var arg : f->args) { for (Var arg : args) {
m.use_count_[arg.get()] = 0; m.use_count_[arg.get()] = 0;
} }
m.Mutate(f->body); m.Mutate(stmt);
return m.undefined_; return m.undefined_;
} }
......
...@@ -7,8 +7,11 @@ ...@@ -7,8 +7,11 @@
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <dmlc/timer.h>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <cstdlib>
#include <thread>
#include "./runtime_base.h" #include "./runtime_base.h"
#include "./device_api.h" #include "./device_api.h"
...@@ -71,6 +74,24 @@ using namespace tvm::runtime; ...@@ -71,6 +74,24 @@ using namespace tvm::runtime;
struct TVMRuntimeEntry { struct TVMRuntimeEntry {
std::string ret_str; std::string ret_str;
std::string last_error; std::string last_error;
// threads used in parallel for
std::vector<std::thread> par_threads;
// errors created in parallel for.
std::vector<std::string> par_errors;
// number of parallel threads
int num_par_threads{1};
TVMRuntimeEntry() {
const char *val = getenv("TVM_NUM_THREADS");
if (val == nullptr) {
val = getenv("OMP_NUM_THREADS");
}
if (val != nullptr) {
num_par_threads = atoi(val);
} else {
num_par_threads = std::thread::hardware_concurrency();
}
}
}; };
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore; typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
...@@ -123,6 +144,12 @@ int TVMModPreCompile(TVMModuleHandle mod, ...@@ -123,6 +144,12 @@ int TVMModPreCompile(TVMModuleHandle mod,
API_END(); API_END();
} }
int TVMModFree(TVMModuleHandle mod) {
API_BEGIN();
delete static_cast<Module*>(mod);
API_END();
}
int TVMBackendGetFuncFromEnv(void* mod_node, int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name, const char* func_name,
TVMFunctionHandle *func) { TVMFunctionHandle *func) {
...@@ -132,10 +159,44 @@ int TVMBackendGetFuncFromEnv(void* mod_node, ...@@ -132,10 +159,44 @@ int TVMBackendGetFuncFromEnv(void* mod_node,
API_END(); API_END();
} }
int TVMModFree(TVMModuleHandle mod) { int TVMBackendParallelFor(
API_BEGIN(); int64_t begin,
delete static_cast<Module*>(mod); int64_t end,
API_END(); int (*lambda)(int64_t begin, int64_t end, void* env),
void* env) {
TVMRuntimeEntry* rt = TVMAPIRuntimeStore::Get();
int nthread = rt->num_par_threads;
rt->par_threads.resize(nthread);
rt->par_errors.clear();
rt->par_errors.resize(nthread);
int64_t step = (end - begin + nthread - 1) / nthread;
auto fexec = [lambda, env, begin, end, step, rt](int i) {
int64_t ibegin = std::min(end, begin + step * i);
int64_t iend = std::min(end, begin + step * (i + 1));
int rv = (*lambda)(ibegin, iend, env);
if (rv != 0) {
std::ostringstream os;
os << "Thread " << i << " error:" << TVMGetLastError();
rt->par_errors[i] = os.str();
}
};
for (int i = 0; i < nthread; ++i) {
rt->par_threads[i] = std::thread(fexec, i);
}
int ret = 0;
for (int i = 0; i < nthread; ++i) {
rt->par_threads[i].join();
if (rt->par_errors[i].length() != 0) ret = -1;
}
if (ret == 0) return ret;
std::ostringstream os;
for (int i = 0; i < nthread; ++i) {
if (rt->par_errors[i].length() != 0) {
os << rt->par_errors[i] << '\n';
}
}
rt->last_error = os.str();
return -1;
} }
int TVMFuncFree(TVMFunctionHandle func) { int TVMFuncFree(TVMFunctionHandle func) {
......
...@@ -69,6 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -69,6 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
switch (op->iter_type) { switch (op->iter_type) {
case kUnrolled: p->stream << "unroll"; break; case kUnrolled: p->stream << "unroll"; break;
case kVectorized: p->stream << "vectorize"; break; case kVectorized: p->stream << "vectorize"; break;
case kParallel: p->stream << "parallel"; break;
} }
}); });
...@@ -246,6 +247,11 @@ Stage& Stage::unroll(IterVar var) { // NOLINT(*) ...@@ -246,6 +247,11 @@ Stage& Stage::unroll(IterVar var) { // NOLINT(*)
return *this; return *this;
} }
Stage& Stage::parallel(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kParallel));
return *this;
}
Schedule::Schedule(Array<Operation> ops) { Schedule::Schedule(Array<Operation> ops) {
auto n = std::make_shared<ScheduleNode>(); auto n = std::make_shared<ScheduleNode>();
n->outputs = ops; n->outputs = ops;
......
...@@ -189,6 +189,7 @@ MakeLoopNest(const Stage& sch, ...@@ -189,6 +189,7 @@ MakeLoopNest(const Stage& sch,
if (sch->iter_var_attrs.count(iv)) { if (sch->iter_var_attrs.count(iv)) {
switch (sch->iter_var_attrs[iv]->iter_type) { switch (sch->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break; case kUnrolled: for_type = ForType::Unrolled; break;
case kParallel: for_type = ForType::Parallel; break;
case kVectorized: for_type = ForType::Vectorized; break; case kVectorized: for_type = ForType::Vectorized; break;
} }
} }
......
import tvm
import numpy as np
def test_llvm_add_pipeline():
n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op)
s[C].parallel(C.op.axis[0])
def check_llvm():
if not tvm.codegen.enabled("llvm"):
return
# build and invoke the kernel.
f = tvm.build(s, [A, B, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
n = 10270 * 2460
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
for i in range(1000):
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_llvm()
if __name__ == "__main__":
test_llvm_add_pipeline()
...@@ -78,40 +78,7 @@ def test_stack_vm_cond(): ...@@ -78,40 +78,7 @@ def test_stack_vm_cond():
np.testing.assert_equal(a.asnumpy(), y) np.testing.assert_equal(a.asnumpy(), y)
run_jit(fapi, check) run_jit(fapi, check)
def test_llvm_add_pipeline():
n = tvm.Var('n')
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
Bb = tvm.Buffer(B.shape, B.dtype, name='B')
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
def check_llvm():
if not tvm.codegen.enabled("llvm"):
return
# build and invoke the kernel.
f = tvm.codegen.build(fapi, "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_llvm()
if __name__ == "__main__": if __name__ == "__main__":
test_stack_vm_basic() test_stack_vm_basic()
test_stack_vm_cond() test_stack_vm_cond()
test_stack_vm_loop() test_stack_vm_loop()
test_llvm_add_pipeline()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment