Commit 45597d00 by Tianqi Chen Committed by GitHub

[LANG/PASS] Support Vectorize (#37)

parent 6a62beb2
...@@ -62,6 +62,7 @@ class IRMutator { ...@@ -62,6 +62,7 @@ class IRMutator {
virtual Stmt Mutate_(const Realize* op, const Stmt& s); virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s); virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s); virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e); virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s); virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e); virtual Expr Mutate_(const Variable* op, const Expr& e);
......
...@@ -112,6 +112,12 @@ Stmt StorageFlatten(Stmt stmt, ...@@ -112,6 +112,12 @@ Stmt StorageFlatten(Stmt stmt,
Stmt UnrollLoop(Stmt stmt, int max_auto_step); Stmt UnrollLoop(Stmt stmt, int max_auto_step);
/*! /*!
* \brief vectorize the constant loops
* \param stmt The statment to be vectorized.
*/
Stmt VectorizeLoop(Stmt stmt);
/*!
* \brief Make an user callable API LoweredFunc. * \brief Make an user callable API LoweredFunc.
* *
* The main task of this function is to create code to : * The main task of this function is to create code to :
......
...@@ -18,6 +18,8 @@ class StageNode; ...@@ -18,6 +18,8 @@ class StageNode;
class ScheduleNode; class ScheduleNode;
// Node container for IterVarRelation // Node container for IterVarRelation
class IterVarRelationNode; class IterVarRelationNode;
// Attribute of itervar.
class IterVarAttrNode;
/*! \brief the attachment type */ /*! \brief the attachment type */
enum AttachType : int { enum AttachType : int {
...@@ -27,6 +29,12 @@ enum AttachType : int { ...@@ -27,6 +29,12 @@ enum AttachType : int {
kScope = 3 kScope = 3
}; };
/*! \brief IterVar type */
enum IterVarType : int {
kUnrolled = 1,
kVectorized = 2
};
/*! \brief Stage, contains scheduling for a stage of computation. */ /*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef { class Stage : public NodeRef {
public: public:
...@@ -124,11 +132,22 @@ class Stage : public NodeRef { ...@@ -124,11 +132,22 @@ class Stage : public NodeRef {
IterVar* p_x_inner, IterVar* p_y_inner, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor); Expr x_factor, Expr y_factor);
/*! /*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled. * \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled. * \return whether the stage has been scheduled.
*/ */
inline bool is_scheduled() const; inline bool is_scheduled() const;
// declare container type // declare container type
using ContainerType = StageNode; using ContainerType = StageNode;
}; };
...@@ -193,6 +212,21 @@ class IterVarRelation : public NodeRef { ...@@ -193,6 +212,21 @@ class IterVarRelation : public NodeRef {
inline const IterVarRelationNode* operator->() const; inline const IterVarRelationNode* operator->() const;
}; };
/*!
* \brief Additional scheduable attributes about IterVar.
*/
class IterVarAttr : public NodeRef {
public:
IterVarAttr() {}
explicit IterVarAttr(IterVarType t);
explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarAttrNode* operator->() const;
};
// defintion of node containers // defintion of node containers
/*! /*!
* \brief represents the schedule of the tensor * \brief represents the schedule of the tensor
...@@ -223,6 +257,8 @@ class StageNode : public Node { ...@@ -223,6 +257,8 @@ class StageNode : public Node {
Array<IterVar> leaf_iter_vars; Array<IterVar> leaf_iter_vars;
/*! \brief The relation bwteen of IterVars */ /*! \brief The relation bwteen of IterVars */
Array<IterVarRelation> relations; Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
Map<IterVar, IterVarAttr> iter_var_attrs;
/*! \brief The attachment type of the schedule */ /*! \brief The attachment type of the schedule */
AttachType attach_type{kNone}; AttachType attach_type{kNone};
/*! \brief The attach point of this schedule. */ /*! \brief The attach point of this schedule. */
...@@ -236,6 +272,7 @@ class StageNode : public Node { ...@@ -236,6 +272,7 @@ class StageNode : public Node {
v->Visit("all_iter_vars", &all_iter_vars); v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars); v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations); v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type); v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar); v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage); v->Visit("attach_stage", &attach_stage);
...@@ -268,6 +305,20 @@ class ScheduleNode : public Node { ...@@ -268,6 +305,20 @@ class ScheduleNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode); TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
}; };
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
public:
/*! \brief The iteration type. */
IterVarType iter_type;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
}
static constexpr const char* _type_key = "IterVarAttr";
TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode);
};
/*! \brief base node of iteration var */ /*! \brief base node of iteration var */
class IterVarRelationNode : public Node { class IterVarRelationNode : public Node {
}; };
...@@ -372,5 +423,9 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const { ...@@ -372,5 +423,9 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get()); return static_cast<const IterVarRelationNode*>(node_.get());
} }
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(node_.get());
}
} // namespace tvm } // namespace tvm
#endif // TVM_SCHEDULE_H_ #endif // TVM_SCHEDULE_H_
...@@ -69,6 +69,7 @@ def build(sch, ...@@ -69,6 +69,7 @@ def build(sch,
stmt = schedule.ScheduleOps(sch, bounds) stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds) stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt) stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step) stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt) stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list)) fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
......
...@@ -177,3 +177,23 @@ class Stage(NodeBase): ...@@ -177,3 +177,23 @@ class Stage(NodeBase):
x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile( x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor) self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner return x_outer, y_outer, x_inner, y_inner
def vectorize(self, var):
"""Vectorize the iteration.
Parameters
----------
var : IterVar
The iteration to be vectorize
"""
_api_internal._StageVectorize(self, var)
def unroll(self, var):
"""Unroll the iteration.
Parameters
----------
var : IterVar
The iteration to be unrolled.
"""
_api_internal._StageUnroll(self, var)
...@@ -253,6 +253,18 @@ TVM_REGISTER_API(_StageTile) ...@@ -253,6 +253,18 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner}); *ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
}); });
TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.unroll(args[1]);
});
TVM_REGISTER_API(_StageVectorize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.vectorize(args[1]);
});
TVM_REGISTER_API(_ScheduleNormalize) TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule() args[0].operator Schedule()
......
...@@ -62,6 +62,7 @@ REGISTER_PASS1(VerifySSA); ...@@ -62,6 +62,7 @@ REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify); REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline); REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten); REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(UnrollLoop); REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync); REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI); REGISTER_PASS4(MakeAPI);
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <pass/Interval.h> #include <pass/Interval.h>
#include <limits>
namespace tvm { namespace tvm {
namespace arith { namespace arith {
...@@ -52,6 +53,23 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) { ...@@ -52,6 +53,23 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
} }
} }
// get a small constant int
inline bool GetConstInt(Expr e, int* out) {
int64_t v1 = 0;
uint64_t v2 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true;
}
if (GetConst(e, &v2)) {
if (v2 > static_cast<uint64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v2); return true;
}
return false;
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \ #define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \ int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \ if (GetConst(a, &ia) && GetConst(b, &ib)) { \
......
...@@ -102,6 +102,20 @@ class CodeGenC { ...@@ -102,6 +102,20 @@ class CodeGenC {
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*) virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os); // NOLINT(*)
virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value); // NOLINT(*)
virtual void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*)
virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value);
/*! \brief function print into the ostream */ /*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*) using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to to print normal code */ /*! \brief function to to print normal code */
...@@ -116,17 +130,10 @@ class CodeGenC { ...@@ -116,17 +130,10 @@ class CodeGenC {
std::ostringstream stream; std::ostringstream stream;
protected: protected:
// additional string for arg addr_space. // print reference to a buffer as type t in index.
std::string arg_addr_space_; void PrintBufferRef(const Variable* buffer,
Type t, Expr index,
private: std::ostream& os); // NOLINT(*)
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*! /*!
* \brief Get the SSA ID corresponds to src * \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment * If necessary, generate new assignment
...@@ -135,6 +142,19 @@ class CodeGenC { ...@@ -135,6 +142,19 @@ class CodeGenC {
*/ */
std::string SSAGetID(std::string src, Type t); std::string SSAGetID(std::string src, Type t);
/*! /*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*!
* \brief mark the beginning of a new scope * \brief mark the beginning of a new scope
* \return The scope id. * \return The scope id.
*/ */
...@@ -155,25 +175,28 @@ class CodeGenC { ...@@ -155,25 +175,28 @@ class CodeGenC {
* \param buf_var The buffer variable. * \param buf_var The buffer variable.
* \param t The type to be checked. * \param t The type to be checked.
*/ */
void HandleTypeRegister(const Variable* buf_var, Type t); void RegisterHandleType(const Variable* buf_var, Type t);
/*! /*!
* \brief get a unique name with the corresponding prefix * \brief Get the storage scope of buf_var.
* \param prefix The prefix of the name * \param buf_var The buf_var to be queryed.
* \return The returned name. * \return The storage scope.
*/ */
std::string GetUniqueName(std::string prefix); std::string GetStorageScope(const Variable* buf_var) const;
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief the storage scope of allocation */ /*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_; std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
/*! \brief name allocation map */ /*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_; std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */ /*! \brief assignment map of ssa */
std::unordered_map<std::string, SSAEntry> ssa_assign_map_; std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief array to check whether we are inside certain scope */ /*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_; std::vector<bool> scope_mark_;
}; };
......
...@@ -22,6 +22,108 @@ std::string CodeGenCUDA::Compile( ...@@ -22,6 +22,108 @@ std::string CodeGenCUDA::Compile(
return CodeGenC::Compile(f, output_ssa); return CodeGenC::Compile(f, output_ssa);
} }
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "do not yet support vector types";
os << "void*"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16: os << "half"; break;
case 32: os << "float"; break;
case 64: os << "double"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
}
if (t.bits() == 8 && t.lanes() == 4) {
// directly 4 8 bit int in integer.
os << "int"; return;
}
switch (t.bits()) {
case 8: os << "char"; break;
case 16: os << "short"; break;
case 32: os << "int"; break;
case 64: {
if (lanes != 1 && sizeof(long) == 64) { // NOLINT(*)
os << "long"; break;
} else {
os << "int64_t"; break;
}
}
case 1: os << "int"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
void CodeGenCUDA::PrintVecBinaryOp(
const std::string&op, Type t,
Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
// unpacking operations.
int lanes = t.lanes();
{
// default: unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.type());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.type());
std::string sret = GetUniqueName("_");
{
// delcare type.
this->PrintIndent();
this->PrintType(t, stream);
stream << ' ' << sret << ";\n";
}
for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp;
if (isalpha(op[0])) {
value_temp << op << "(";
PrintVecElemLoad(vlhs, lhs.type(), i, value_temp);
value_temp << ", ";
PrintVecElemLoad(vrhs, rhs.type(), i, value_temp);
value_temp << ")";
} else {
value_temp << "(";
PrintVecElemLoad(vlhs, lhs.type(), i, value_temp);
value_temp << op;
PrintVecElemLoad(vrhs, rhs.type(), i, value_temp);
value_temp << ")";
}
PrintVecElemStore(sret, t, i, value_temp.str());
}
os << sret;
}
}
void CodeGenCUDA::PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) { // NOLINT(*)
const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
os << vec << "." << access[i];
}
void CodeGenCUDA::PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) {
this->PrintIndent();
const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
stream << vec << "." << access[i] << " = " << value << ";\n";
}
void CodeGenCUDA::PrintStorageSync(const std::string& sync) { void CodeGenCUDA::PrintStorageSync(const std::string& sync) {
if (sync == "shared") { if (sync == "shared") {
this->PrintIndent(); this->PrintIndent();
...@@ -43,8 +145,6 @@ void CodeGenCUDA::PrintStorageScope( ...@@ -43,8 +145,6 @@ void CodeGenCUDA::PrintStorageScope(
std::unordered_map<LoweredFunc, PackedFunc> std::unordered_map<LoweredFunc, PackedFunc>
MakeNVRTC(Array<LoweredFunc> funcs) { MakeNVRTC(Array<LoweredFunc> funcs) {
std::ostringstream os; std::ostringstream os;
os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n";
bool output_ssa = false; bool output_ssa = false;
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
os << CodeGenCUDA().Compile(f, output_ssa); os << CodeGenCUDA().Compile(f, output_ssa);
...@@ -56,6 +156,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) { ...@@ -56,6 +156,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc"); const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string(); code = f(code).operator std::string();
} }
LOG(INFO) << code;
std::string ptx; std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) { if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile"); const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
......
...@@ -25,9 +25,18 @@ class CodeGenCUDA : public CodeGenC { ...@@ -25,9 +25,18 @@ class CodeGenCUDA : public CodeGenC {
*/ */
std::string Compile(LoweredFunc f, std::string Compile(LoweredFunc f,
bool output_ssa); bool output_ssa);
// override behavior // override behavior
void PrintStorageSync(const std::string& sync) final; void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
const std::string&op, Type t,
Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
}; };
} // namespace codegen } // namespace codegen
......
...@@ -19,7 +19,11 @@ std::string CodeGenOpenCL::Compile( ...@@ -19,7 +19,11 @@ std::string CodeGenOpenCL::Compile(
LoweredFunc f, LoweredFunc f,
bool output_ssa) { bool output_ssa) {
this->stream << " __kernel "; this->stream << " __kernel ";
this->arg_addr_space_ = "__global "; for (Var arg : f->args) {
if (arg.type().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
return CodeGenC::Compile(f, output_ssa); return CodeGenC::Compile(f, output_ssa);
} }
...@@ -34,6 +38,80 @@ void CodeGenOpenCL::PrintThreadIndexExpr( ...@@ -34,6 +38,80 @@ void CodeGenOpenCL::PrintThreadIndexExpr(
} }
} }
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "do not yet support vector types";
os << "void*"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16: os << "half"; break;
case 32: os << "float"; break;
case 64: os << "double"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes; return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
}
if (t.bits() == 8 && t.lanes() == 4) {
// directly 4 8 bit int in integer.
os << "int"; return;
}
switch (t.bits()) {
case 8: os << "char"; break;
case 16: os << "short"; break;
case 32: os << "int"; break;
case 64: os << "long"; break;
case 1: os << "int"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes; return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}
void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
Expr base, std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
auto it = alloc_storage_scope_.find(buffer);
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
os << ' ';
PrintType(t.element_of(), os);
os << "*)";
}
os << GetVarID(buffer) << " + ";
PrintExpr(base, os);
}
void CodeGenOpenCL::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
os << ")";
}
void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
this->PrintIndent();
stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
PrintVecAddr(buffer, t, base, stream);
stream << ");\n";
}
void CodeGenOpenCL::PrintStorageSync(const std::string& sync) { void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
if (sync == "shared") { if (sync == "shared") {
...@@ -45,8 +123,9 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) { ...@@ -45,8 +123,9 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
} }
void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*) void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global"); if (scope == "global") {
if (scope == "shared") { os << "__global";
} else if (scope == "shared") {
os << "__local "; os << "__local ";
} }
} }
...@@ -55,8 +134,6 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os ...@@ -55,8 +134,6 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os
std::unordered_map<LoweredFunc, PackedFunc> std::unordered_map<LoweredFunc, PackedFunc>
MakeOpenCL(Array<LoweredFunc> funcs) { MakeOpenCL(Array<LoweredFunc> funcs) {
std::ostringstream os; std::ostringstream os;
os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n";
bool output_ssa = false; bool output_ssa = false;
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
os << CodeGenOpenCL().Compile(f, output_ssa); os << CodeGenOpenCL().Compile(f, output_ssa);
......
...@@ -30,6 +30,16 @@ class CodeGenOpenCL : public CodeGenC { ...@@ -30,6 +30,16 @@ class CodeGenOpenCL : public CodeGenC {
std::string tag, std::ostream& os) final; // NOLINT(*) std::string tag, std::ostream& os) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const std::string& scope) final; // NOLINT(*) void PrintStorageSync(const std::string& scope) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) final; // NOLINT(*)
void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
void PrintVecAddr(const Variable* buffer, Type t,
Expr base, std::ostream& os); // NOLINT(*)
}; };
} // namespace codegen } // namespace codegen
......
...@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Provide) .DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize) .DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Store) .DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For) .DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate) .DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Free); .DISPATCH_TO_MUTATE_STMT(Free);
...@@ -195,6 +196,22 @@ Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) { ...@@ -195,6 +196,22 @@ Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
return s; return s;
} }
Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Call) .DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Let) .DISPATCH_TO_MUTATE_EXPR(Let)
...@@ -363,21 +380,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) ...@@ -363,21 +380,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return Block::make(first, rest); return Block::make(first, rest);
} }
}) })
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
else_case = m->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) { .set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->Mutate(op->value); Expr v = m->Mutate(op->value);
if (v.same_as(op->value)) { if (v.same_as(op->value)) {
......
...@@ -101,9 +101,14 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -101,9 +101,14 @@ class IRUseDefAnalysis : public IRMutator {
} }
void HandleDef(const Variable* v) { void HandleDef(const Variable* v) {
CHECK(!def_count_.count(v))
<< "variable " << v->name_hint
<< " has already been defined, the Stmt is not SSA";
CHECK(!use_count_.count(v)) CHECK(!use_count_.count(v))
<< "variable is already defined"; << "variable " << v->name_hint
<< " has been used before definition!";
use_count_[v] = 0; use_count_[v] = 0;
def_count_[v] = 1;
} }
void HandleUse(const Expr& v) { void HandleUse(const Expr& v) {
...@@ -127,6 +132,7 @@ class IRUseDefAnalysis : public IRMutator { ...@@ -127,6 +132,7 @@ class IRUseDefAnalysis : public IRMutator {
Array<IterVar> thread_axis_; Array<IterVar> thread_axis_;
Array<Expr> thread_extent_; Array<Expr> thread_extent_;
std::unordered_map<const Variable*, int> use_count_; std::unordered_map<const Variable*, int> use_count_;
std::unordered_map<const Variable*, int> def_count_;
}; };
class HostDeviceSplitter : public IRMutator { class HostDeviceSplitter : public IRMutator {
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2017 by Contributors
* SSA related checks and pass. * Loop unrolling.
* \file ssa.cc * \file unroll_loop.cc
*/ */
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include <unordered_set> #include <unordered_set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "../arithmetic//compute_expr.h" #include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
...@@ -33,7 +33,8 @@ class LoopUnroller : public IRMutator { ...@@ -33,7 +33,8 @@ class LoopUnroller : public IRMutator {
if (v2 != nullptr) { if (v2 != nullptr) {
value = static_cast<int>(v2->value); value = static_cast<int>(v2->value);
} }
bool allow_unroll = value >= 0 && value <= max_auto_step_; bool allow_unroll = (op->for_type == ForType::Serial &&
value >= 0 && value <= max_auto_step_);
if (op->for_type == ForType::Unrolled) { if (op->for_type == ForType::Unrolled) {
CHECK_GE(value, 0) CHECK_GE(value, 0)
<< "Cannot unroll non-constant loop"; << "Cannot unroll non-constant loop";
......
...@@ -57,6 +57,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -57,6 +57,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< ")"; << ")";
}); });
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
switch (op->iter_type) {
case kUnrolled: p->stream << "unroll"; break;
case kVectorized: p->stream << "vectorize"; break;
}
});
Stage::Stage(Operation op) { Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>(); auto n = std::make_shared<StageNode>();
n->op = op; n->op = op;
...@@ -246,7 +254,38 @@ void Schedule::normalize() { ...@@ -246,7 +254,38 @@ void Schedule::normalize() {
} }
} }
IterVarAttr::IterVarAttr(IterVarType t) {
std::shared_ptr<IterVarAttrNode> n = std::make_shared<IterVarAttrNode>();
n->iter_type = t;
node_ = n;
}
inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var);
if (it != self->iter_var_attrs.end()) {
CHECK_EQ((*it).second->iter_type, attr->iter_type)
<< "IterVar's is already set to "
<< (*it).second << " instead of " << attr;
} else {
self->iter_var_attrs.Set(var, attr);
}
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kVectorized));
return *this;
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kUnrolled));
return *this;
}
TVM_REGISTER_NODE_TYPE(StageNode); TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode); TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode); TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode); TVM_REGISTER_NODE_TYPE(RebaseNode);
......
...@@ -177,6 +177,13 @@ MakeLoopNest(const Stage& sch, ...@@ -177,6 +177,13 @@ MakeLoopNest(const Stage& sch,
} }
// Mark the iter var in the IR, to remember the point // Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) { if (iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
if (sch->iter_var_attrs.count(iv)) {
switch (sch->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break;
}
}
if (is_one(dom->extent)) { if (is_one(dom->extent)) {
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op)); LetStmt::make(var, dom->min, no_op));
...@@ -184,13 +191,13 @@ MakeLoopNest(const Stage& sch, ...@@ -184,13 +191,13 @@ MakeLoopNest(const Stage& sch,
} else if (is_zero(dom->min)) { } else if (is_zero(dom->min)) {
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
For::make(var, 0, dom->extent, For::make(var, 0, dom->extent,
ForType::Serial, DeviceAPI::None, no_op)); for_type, DeviceAPI::None, no_op));
value_map[iv] = var; value_map[iv] = var;
} else { } else {
Var idx(iv->var->name_hint + ".idx", iv->var.type()); Var idx(iv->var->name_hint + ".idx", iv->var.type());
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent, For::make(idx, 0, dom->extent,
ForType::Serial, DeviceAPI::None, no_op)); for_type, DeviceAPI::None, no_op));
Expr new_value = dom->min + idx; Expr new_value = dom->min + idx;
value_map[iv] = new_value; value_map[iv] = new_value;
nest[i + 1].emplace_back( nest[i + 1].emplace_back(
......
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
def test_add(): def test_add():
# graph # graph
n = tvm.Var('n') n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C') C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
...@@ -13,26 +13,28 @@ def test_add(): ...@@ -13,26 +13,28 @@ def test_add():
num_thread = 256 num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x") block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x") thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread, outer=block_x) _, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x)
_, x = s[C].split(x, outer=thread_x) _, x = s[C].split(x, outer=thread_x)
_, x = s[C].split(x, factor=4)
s[C].vectorize(x)
# one line to build the function. # one line to build the function.
def check_device(target):
codes = [] codes = []
fadd = tvm.build(s, fadd = tvm.build(s, [A, B, C],
args=[A, B, C], target, record_codes=codes,
target="cuda", name="myadd", name="myadd")
record_codes=codes) if target == "cuda":
for c in codes: ctx = tvm.gpu(0)
print(c) else:
ctx = tvm.cl(0)
# call the function
num_device = 1
for i in range(num_device):
ctx = tvm.gpu(i)
if not ctx.enabled: if not ctx.enabled:
continue return
for c in codes[1:]:
print(c)
# launch the kernel. # launch the kernel.
n = 1027 n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
...@@ -40,6 +42,10 @@ def test_add(): ...@@ -40,6 +42,10 @@ def test_add():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy()) c.asnumpy(), a.asnumpy() + b.asnumpy())
tvm.init_opencl()
check_device("cuda")
check_device("opencl")
if __name__ == "__main__": if __name__ == "__main__":
test_add() test_add()
...@@ -76,6 +76,21 @@ def test_fuse(): ...@@ -76,6 +76,21 @@ def test_fuse():
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations) assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi) assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
def test_vectorize():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
s = tvm.Schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
s[T].vectorize(yi)
s[T].unroll(xi)
UNROLL = 1
VECTORIZE = 2
assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
if __name__ == "__main__": if __name__ == "__main__":
test_schedule_create() test_schedule_create()
...@@ -83,3 +98,4 @@ if __name__ == "__main__": ...@@ -83,3 +98,4 @@ if __name__ == "__main__":
test_tile() test_tile()
test_split() test_split()
test_fuse() test_fuse()
test_vectorize()
...@@ -9,11 +9,13 @@ def test_unroll_loop(): ...@@ -9,11 +9,13 @@ def test_unroll_loop():
# for i in 0 to n-1: # for i in 0 to n-1:
stmt = tvm.make.For( stmt = tvm.make.For(
i, n, 2, 0, 0, i, n, 2, 0, 0,
tvm.make.For(j, 0, n, 0, 0, tvm.make.For(j, 0, 8, 3, 0,
tvm.make.Store(Ab.data, tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1, tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1))) j + 1)))
stmt = tvm.ir_pass.UnrollLoop(stmt, 8) assert isinstance(stmt, tvm.stmt.For)
stmt = tvm.ir_pass.UnrollLoop(stmt, 4)
assert not isinstance(stmt, tvm.stmt.For)
print(stmt) print(stmt)
if __name__ == "__main__": if __name__ == "__main__":
......
import tvm
def test_vectorize_loop():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
j = tvm.Var('j')
VECTORIZE = 2
# for i in 0 to n-1:
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, 4, VECTORIZE, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
assert isinstance(stmt.body, tvm.stmt.For)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
assert isinstance(stmt, tvm.stmt.For)
assert not isinstance(stmt.body, tvm.stmt.For)
print(stmt)
if __name__ == "__main__":
test_vectorize_loop()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment