Commit 45597d00 by Tianqi Chen, committed by GitHub

[LANG/PASS] Support Vectorize (#37)

parent 6a62beb2
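For orientation, a minimal sketch of how the scheduling primitives introduced in this commit are driven from Python (it condenses the test_vectorize and test_add tests added later in this diff; the tensor names and tile factors are illustrative, not part of the commit):

import tvm

m, n = tvm.Var('m'), tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
s = tvm.Schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
s[T].vectorize(yi)  # lowered to a ForType::Vectorized loop, expanded by VectorizeLoop
s[T].unroll(xi)     # lowered to a ForType::Unrolled loop, expanded by UnrollLoop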
...@@ -62,6 +62,7 @@ class IRMutator {
virtual Stmt Mutate_(const Realize* op, const Stmt& s);
virtual Stmt Mutate_(const Store* op, const Stmt& s);
virtual Stmt Mutate_(const Free* op, const Stmt& s);
virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
virtual Expr Mutate_(const Call* op, const Expr& e);
virtual Expr Mutate_(const Load* op, const Expr& s);
virtual Expr Mutate_(const Variable* op, const Expr& e);
......
...@@ -112,6 +112,12 @@ Stmt StorageFlatten(Stmt stmt,
Stmt UnrollLoop(Stmt stmt, int max_auto_step);
/*!
* \brief vectorize the constant loops
* \param stmt The statement to be vectorized.
*/
Stmt VectorizeLoop(Stmt stmt);
/*!
* \brief Make a user-callable API LoweredFunc.
*
* The main task of this function is to create code to :
......
...@@ -18,6 +18,8 @@ class StageNode;
class ScheduleNode;
// Node container for IterVarRelation
class IterVarRelationNode;
// Attribute of itervar.
class IterVarAttrNode;
/*! \brief the attachment type */
enum AttachType : int {
...@@ -27,6 +29,12 @@ enum AttachType : int {
kScope = 3
};
/*! \brief IterVar type */
enum IterVarType : int {
kUnrolled = 1,
kVectorized = 2
};
/*! \brief Stage, contains scheduling for a stage of computation. */
class Stage : public NodeRef {
public:
...@@ -124,11 +132,22 @@ class Stage : public NodeRef {
IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor);
/*!
* \brief Vectorize iteration.
* \param var The axis to be vectorized.
* \return reference to self.
*/
Stage& vectorize(IterVar var); // NOLINT(*)
/*!
* \brief Unroll iteration.
* \param var The axis to be unrolled.
* \return reference to self.
*/
Stage& unroll(IterVar var); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
*/
inline bool is_scheduled() const;
// declare container type
using ContainerType = StageNode;
};
...@@ -193,6 +212,21 @@ class IterVarRelation : public NodeRef {
inline const IterVarRelationNode* operator->() const;
};
/*!
* \brief Additional schedulable attributes about IterVar.
*/
class IterVarAttr : public NodeRef {
public:
IterVarAttr() {}
explicit IterVarAttr(IterVarType t);
explicit IterVarAttr(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarAttrNode* operator->() const;
};
// definition of node containers
/*!
* \brief represents the schedule of the tensor
...@@ -223,6 +257,8 @@ class StageNode : public Node {
Array<IterVar> leaf_iter_vars;
/*! \brief The relation between IterVars */
Array<IterVarRelation> relations;
/*! \brief additional attributes about iter var. */
Map<IterVar, IterVarAttr> iter_var_attrs;
/*! \brief The attachment type of the schedule */
AttachType attach_type{kNone};
/*! \brief The attach point of this schedule. */
...@@ -236,6 +272,7 @@ class StageNode : public Node {
v->Visit("all_iter_vars", &all_iter_vars);
v->Visit("leaf_iter_vars", &leaf_iter_vars);
v->Visit("relations", &relations);
v->Visit("iter_var_attrs", &iter_var_attrs);
v->Visit("attach_type", &attach_type);
v->Visit("attach_ivar", &attach_ivar);
v->Visit("attach_stage", &attach_stage);
...@@ -268,6 +305,20 @@ class ScheduleNode : public Node {
TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode);
};
/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Node {
public:
/*! \brief The iteration type. */
IterVarType iter_type;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("iter_type", &iter_type);
}
static constexpr const char* _type_key = "IterVarAttr";
TVM_DECLARE_NODE_TYPE_INFO(IterVarAttrNode);
};
/*! \brief base node of iteration var */
class IterVarRelationNode : public Node {
};
...@@ -372,5 +423,9 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
return static_cast<const IterVarRelationNode*>(node_.get());
}
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(node_.get());
}
} // namespace tvm
#endif  // TVM_SCHEDULE_H_
...@@ -69,6 +69,7 @@ def build(sch,
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
stmt = ir_pass.CanonicalSimplify(stmt)
stmt = ir_pass.VectorizeLoop(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
......
...@@ -177,3 +177,23 @@ class Stage(NodeBase):
x_outer, y_outer, x_inner, y_inner = _api_internal._StageTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner
def vectorize(self, var):
"""Vectorize the iteration.
Parameters
----------
var : IterVar
The iteration to be vectorized.
"""
_api_internal._StageVectorize(self, var)
def unroll(self, var):
"""Unroll the iteration.
Parameters
----------
var : IterVar
The iteration to be unrolled.
"""
_api_internal._StageUnroll(self, var)
...@@ -253,6 +253,18 @@ TVM_REGISTER_API(_StageTile)
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});
TVM_REGISTER_API(_StageUnroll)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.unroll(args[1]);
});
TVM_REGISTER_API(_StageVectorize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Stage()
.vectorize(args[1]);
});
TVM_REGISTER_API(_ScheduleNormalize)
.set_body([](TVMArgs args, TVMRetValue* ret) {
args[0].operator Schedule()
......
...@@ -62,6 +62,7 @@ REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
......
...@@ -9,6 +9,7 @@
#include <tvm/ir.h>
#include <pass/Interval.h>
#include <limits>
namespace tvm {
namespace arith {
...@@ -52,6 +53,23 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) {
}
}
// get a small constant int
inline bool GetConstInt(Expr e, int* out) {
int64_t v1 = 0;
uint64_t v2 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true;
}
if (GetConst(e, &v2)) {
if (v2 > static_cast<uint64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v2); return true;
}
return false;
}
#define TVM_CONST_PROPAGATION(OP_NAME, OP) \
int64_t ia = 0, ib = 0; \
if (GetConst(a, &ia) && GetConst(b, &ib)) { \
......
...@@ -3,7 +3,9 @@
* \file codegen_c.cc
*/
#include <iomanip>
#include <cctype>
#include "./codegen_c.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace codegen {
...@@ -14,10 +16,10 @@ std::string CodeGenC::Compile(LoweredFunc f,
bool output_ssa) {
print_ssa_form_ = output_ssa;
// skip the first underscore, so SSA variable starts from _1
GetUniqueName("_");
// add to alloc buffer type.
for (const auto & kv : f->handle_data_type) {
RegisterHandleType(kv.first.get(), kv.second.type());
}
this->stream << "void " << f->name << "(";
...@@ -26,7 +28,11 @@ std::string CodeGenC::Compile(LoweredFunc f,
std::string vid = AllocVarID(v.get());
if (i != 0) stream << ", ";
if (v.type().is_handle()) {
auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
stream << ' ';
}
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
...@@ -126,7 +132,7 @@ bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
return it->second == t;
}
void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
auto it = handle_data_type_.find(buf_var);
if (it == handle_data_type_.end()) {
handle_data_type_[buf_var] = t;
...@@ -259,23 +265,39 @@ inline void PrintBinaryExpr(const T* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
if (op->type.lanes() == 1) {
if (isalpha(opstr[0])) {
os << opstr << '(';
p->PrintExpr(op->a, os);
os << ", ";
p->PrintExpr(op->b, os);
os << ')';
} else {
os << '(';
p->PrintExpr(op->a, os);
os << ' ' << opstr << ' ';
p->PrintExpr(op->b, os);
os << ')';
}
} else {
p->PrintVecBinaryOp(opstr, op->type, op->a, op->b, os);
}
}
inline void PrintBinaryIntrinsitc(const Call* op,
const char *opstr,
std::ostream& os, // NOLINT(*)
CodeGenC* p) {
if (op->type.lanes() == 1) {
CHECK_EQ(op->args.size(), 2U);
os << '(';
p->PrintExpr(op->args[0], os);
os << opstr;
p->PrintExpr(op->args[1], os);
os << ')';
} else {
p->PrintVecBinaryOp(opstr, op->type, op->args[0], op->args[1], os);
}
}

TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
...@@ -289,57 +311,49 @@ TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr)
os << p->GetVarID(op);
})
.set_dispatch<Add>([](const Add *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "+", os, p);
})
.set_dispatch<Sub>([](const Sub *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "-", os, p);
})
.set_dispatch<Mul>([](const Mul *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "*", os, p);
})
.set_dispatch<Div>([](const Div *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "/", os, p);
})
.set_dispatch<Mod>([](const Mod *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "%", os, p);
})
.set_dispatch<Min>([](const Min *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "min", os, p);
})
.set_dispatch<Max>([](const Max *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "max", os, p);
})
.set_dispatch<EQ>([](const EQ *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "==", os, p);
})
.set_dispatch<NE>([](const NE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "!=", os, p);
})
.set_dispatch<LT>([](const LT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "<", os, p);
})
.set_dispatch<LE>([](const LE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "<=", os, p);
})
.set_dispatch<GT>([](const GT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, ">", os, p);
})
.set_dispatch<GE>([](const GE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, ">=", os, p);
})
.set_dispatch<And>([](const And *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "&&", os, p);
})
.set_dispatch<Or>([](const Or *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
PrintBinaryExpr(op, "||", os, p);
})
.set_dispatch<Not>([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*)
os << '!';
...@@ -460,18 +474,179 @@ void CodeGenC::PrintExpr(const Call *op, std::ostream& os) { // NOLINT(*)
}
}
void CodeGenC::PrintVecBinaryOp(
const std::string&op, Type t,
Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
if (isalpha(op[0])) {
os << op << "(";
this->PrintExpr(lhs, os);
os << ", ";
this->PrintExpr(rhs, os);
os << ")";
} else {
os <<"(";
this->PrintExpr(lhs, os);
os << ' ' << op << ' ';
this->PrintExpr(rhs, os);
os << ")";
}
}
inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
const Ramp* r = index.as<Ramp>();
if (!r) return false;
if (!is_one(r->stride)) return false;
CHECK_EQ(r->lanes, lanes);
*base = r->base;
return true;
}
// Print a reference expression to a buffer.
void CodeGenC::PrintBufferRef(
const Variable* buffer,
Type t, Expr index,
std::ostream& os) { // NOLINT(*)
std::string vid = GetVarID(buffer);
if (t.lanes() == 1) {
if (!HandleTypeMatch(buffer, t)) {
os << "((";
PrintType(t, os);
os << "*)" << vid << ')';
} else {
os << vid;
}
os << '[';
PrintExpr(index, os);
os << ']';
} else {
// Buffer declared as vector type.
// optimize for case where it is in register,
if (HandleTypeMatch(buffer, t)) {
// optimize for constant access
int offset;
if (arith::GetConstInt(index, &offset)) {
CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']';
return;
}
}
os << "((";
PrintType(t, os);
os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
PrintType(t.element_of(), os);
os << "*)";
}
os << vid << " + ";
PrintExpr(index, os);
os << "))[0]";
}
}
void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes();
if (op->type.lanes() == 1) {
this->PrintBufferRef(op->buffer_var.get(), op->type, op->index, os);
} else {
Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
this->PrintVecLoad(op->buffer_var.get(), op->type, base, os);
} else {
// Load elements separately
std::string sindex = SSAGetID(PrintExpr(op->index), op->index.type());
std::string svalue = GetUniqueName("_");
{
// declare type.
this->PrintIndent();
this->PrintType(op->type, stream);
stream << ' ' << svalue << ";\n";
}
std::string vid = GetVarID(op->buffer_var.get());
Type elem_type = op->type.element_of();
for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp;
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
value_temp << "((";
PrintType(elem_type, os);
value_temp << "*)" << vid << ')';
} else {
value_temp << vid;
}
value_temp << '[';
PrintVecElemLoad(sindex, op->index.type(), i, value_temp);
value_temp << ']';
PrintVecElemStore(svalue, op->type, i, value_temp.str());
}
os << svalue;
}
}
}

void CodeGenC::PrintStmt(const Store* op) {
Type t = op->value.type();
if (t.lanes() == 1) {
this->PrintIndent();
std::string value = this->PrintExpr(op->value);
this->PrintBufferRef(op->buffer_var.get(), t, op->index, stream);
stream << " = " << value << ";\n";
} else {
Expr base;
if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value);
} else {
// store elements separately
std::string index = SSAGetID(PrintExpr(op->index), op->index.type());
std::string value = SSAGetID(PrintExpr(op->value), op->value.type());
std::string vid = GetVarID(op->buffer_var.get());
for (int i = 0; i < t.lanes(); ++i) {
this->PrintIndent();
Type elem_type = t.element_of();
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
stream << "((";
PrintType(elem_type, stream);
stream << "*)" << vid << ')';
} else {
stream << vid;
}
stream << '[';
PrintVecElemLoad(index, op->index.type(), i, stream);
stream << "] = ";
PrintVecElemLoad(value, op->value.type(), i, stream);
stream << ";\n";
}
}
}
}
void CodeGenC::PrintVecElemLoad(const std::string& vec,
Type t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i;
}
void CodeGenC::PrintVecElemStore(const std::string& vec,
Type t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << ".s" << std::hex << i
<< " = " << value << ";\n";
}
void CodeGenC::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
PrintBufferRef(buffer, t, base, os);
}
void CodeGenC::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
this->PrintIndent();
PrintBufferRef(buffer, t, base, stream);
stream << " = " << value << ";\n";
}

void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
...@@ -483,15 +658,15 @@ void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*)
}
void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Ramp: not supported ";
}
void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Broadcast: not supported ";
}
void CodeGenC::PrintExpr(const Select* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Select: not supported ";
}
// Dispatch back to member functions
...@@ -541,23 +716,6 @@ void CodeGenC::PrintStmt(const LetStmt* op) {
PrintStmt(op->body);
}
void CodeGenC::PrintStmt(const Store* op) {
std::string index = this->PrintExpr(op->index);
std::string value = this->PrintExpr(op->value);
this->PrintIndent();
std::string vid = GetVarID(op->buffer_var.get());
if (!HandleTypeMatch(op->buffer_var.get(), op->value.type())) {
this->stream << "((";
PrintType(op->value.type(), this->stream);
this->stream << "*)" << vid << ')';
} else {
this->stream << vid;
}
this->stream << '[' << index
<< "] = " << value
<< ";\n";
}
void CodeGenC::PrintStmt(const Allocate* op) {
CHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
...@@ -580,7 +738,7 @@ void CodeGenC::PrintStmt(const Allocate* op) {
stream << ' '<< vid << '['
<< constant_size << "];\n";
}
RegisterHandleType(op->buffer_var.get(), op->type);
this->PrintStmt(op->body);
}
......
...@@ -102,6 +102,20 @@ class CodeGenC {
virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*)
virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*)
// Binary vector op.
virtual void PrintVecBinaryOp(
const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
virtual void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os); // NOLINT(*)
virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value); // NOLINT(*)
virtual void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os); // NOLINT(*)
virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value);
/*! \brief function print into the ostream */
using FPrintExpr = IRFunctor<void(const NodeRef&, std::ostream& os, CodeGenC *)>; // NOLINT(*)
/*! \brief function to print normal code */
...@@ -116,17 +130,10 @@ class CodeGenC {
std::ostringstream stream;
protected:
// print reference to a buffer as type t in index.
void PrintBufferRef(const Variable* buffer,
Type t, Expr index,
std::ostream& os); // NOLINT(*)
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*!
* \brief Get the SSA ID corresponds to src
* If necessary, generate new assignment
...@@ -135,6 +142,19 @@ class CodeGenC {
*/
std::string SSAGetID(std::string src, Type t);
/*!
* \brief get a unique name with the corresponding prefix
* \param prefix The prefix of the name
* \return The returned name.
*/
std::string GetUniqueName(std::string prefix);
/*! \brief entry in ssa assign map */
struct SSAEntry {
/*! \brief The value id */
std::string vid;
/*! \brief The scope id */
int scope_id;
};
/*!
* \brief mark the beginning of a new scope
* \return The scope id.
*/
...@@ -155,25 +175,28 @@ class CodeGenC {
* \param buf_var The buffer variable.
* \param t The type to be checked.
*/
void RegisterHandleType(const Variable* buf_var, Type t);
/*!
* \brief Get the storage scope of buf_var.
* \param buf_var The buf_var to be queried.
* \return The storage scope.
*/
std::string GetStorageScope(const Variable* buf_var) const;
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
private:
/*! \brief whether to print in SSA form */
bool print_ssa_form_{true};
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;
/*! \brief assignment map of ssa */
std::unordered_map<std::string, SSAEntry> ssa_assign_map_;
/*! \brief name of each variable */
std::unordered_map<const Variable*, std::string> var_idmap_;
/*! \brief the data type of allocated buffers */
std::unordered_map<const Variable*, Type> handle_data_type_;
/*! \brief array to check whether we are inside certain scope */
std::vector<bool> scope_mark_;
};
......
...@@ -22,6 +22,108 @@ std::string CodeGenCUDA::Compile(
return CodeGenC::Compile(f, output_ssa);
}
void CodeGenCUDA::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "do not yet support vector types";
os << "void*"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16: os << "half"; break;
case 32: os << "float"; break;
case 64: os << "double"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
}
if (t.bits() == 8 && t.lanes() == 4) {
// pack four 8-bit ints directly into an int.
os << "int"; return;
}
switch (t.bits()) {
case 8: os << "char"; break;
case 16: os << "short"; break;
case 32: os << "int"; break;
case 64: {
if (lanes != 1 && sizeof(long) == 64) { // NOLINT(*)
os << "long"; break;
} else {
os << "int64_t"; break;
}
}
case 1: os << "int"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes; return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
void CodeGenCUDA::PrintVecBinaryOp(
const std::string&op, Type t,
Expr lhs, Expr rhs, std::ostream& os) { // NOLINT(*)
// unpacking operations.
int lanes = t.lanes();
{
// default: unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.type());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.type());
std::string sret = GetUniqueName("_");
{
// declare type.
this->PrintIndent();
this->PrintType(t, stream);
stream << ' ' << sret << ";\n";
}
for (int i = 0; i < lanes; ++i) {
std::ostringstream value_temp;
if (isalpha(op[0])) {
value_temp << op << "(";
PrintVecElemLoad(vlhs, lhs.type(), i, value_temp);
value_temp << ", ";
PrintVecElemLoad(vrhs, rhs.type(), i, value_temp);
value_temp << ")";
} else {
value_temp << "(";
PrintVecElemLoad(vlhs, lhs.type(), i, value_temp);
value_temp << op;
PrintVecElemLoad(vrhs, rhs.type(), i, value_temp);
value_temp << ")";
}
PrintVecElemStore(sret, t, i, value_temp.str());
}
os << sret;
}
}
void CodeGenCUDA::PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) { // NOLINT(*)
const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
os << vec << "." << access[i];
}
void CodeGenCUDA::PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) {
this->PrintIndent();
const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < 4);
stream << vec << "." << access[i] << " = " << value << ";\n";
}
void CodeGenCUDA::PrintStorageSync(const std::string& sync) {
if (sync == "shared") {
this->PrintIndent();
...@@ -43,8 +145,6 @@ void CodeGenCUDA::PrintStorageScope(
std::unordered_map<LoweredFunc, PackedFunc>
MakeNVRTC(Array<LoweredFunc> funcs) {
std::ostringstream os;
os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n";
bool output_ssa = false;
for (LoweredFunc f : funcs) {
os << CodeGenCUDA().Compile(f, output_ssa);
...@@ -56,6 +156,7 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string();
}
LOG(INFO) << code;
std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
......
...@@ -25,9 +25,18 @@ class CodeGenCUDA : public CodeGenC {
*/
std::string Compile(LoweredFunc f,
bool output_ssa);
// override behavior
void PrintStorageSync(const std::string& sync) final;
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintVecBinaryOp(
const std::string&op, Type t,
Expr lhs, Expr rhs, std::ostream& os) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
};
} // namespace codegen
......
...@@ -19,7 +19,11 @@ std::string CodeGenOpenCL::Compile(
LoweredFunc f,
bool output_ssa) {
this->stream << " __kernel ";
for (Var arg : f->args) {
if (arg.type().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
return CodeGenC::Compile(f, output_ssa);
}
...@@ -34,6 +38,80 @@ void CodeGenOpenCL::PrintThreadIndexExpr(
}
}
void CodeGenOpenCL::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
int lanes = t.lanes();
if (t.is_handle()) {
CHECK_EQ(lanes, 1)
<< "do not yet support vector types";
os << "void*"; return;
}
bool fail = false;
if (t.is_float()) {
switch (t.bits()) {
case 16: os << "half"; break;
case 32: os << "float"; break;
case 64: os << "double"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes; return;
}
} else if (t.is_uint() || t.is_int()) {
if (t.is_uint()) {
os << 'u';
}
if (t.bits() == 8 && t.lanes() == 4) {
// pack four 8-bit ints directly into an int.
os << "int"; return;
}
switch (t.bits()) {
case 8: os << "char"; break;
case 16: os << "short"; break;
case 32: os << "int"; break;
case 64: os << "long"; break;
case 1: os << "int"; break;
default: fail = true; break;
}
if (!fail && lanes == 1) return;
if (!fail && (lanes >= 2 && lanes <= 16)) {
os << lanes; return;
}
}
LOG(FATAL) << "Cannot convert type " << t << " to OpenCL type";
}
void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
Expr base, std::ostream& os) { // NOLINT(*)
if (!HandleTypeMatch(buffer, t.element_of())) {
os << '(';
auto it = alloc_storage_scope_.find(buffer);
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os);
}
os << ' ';
PrintType(t.element_of(), os);
os << "*)";
}
os << GetVarID(buffer) << " + ";
PrintExpr(base, os);
}
void CodeGenOpenCL::PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) {
os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os);
os << ")";
}
void CodeGenOpenCL::PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) {
this->PrintIndent();
stream << "vstore" << t.lanes() << "(" << value << ", 0, ";
PrintVecAddr(buffer, t, base, stream);
stream << ");\n";
}
void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
if (sync == "shared") {
...@@ -45,8 +123,9 @@ void CodeGenOpenCL::PrintStorageSync(const std::string& sync) {
}
void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") {
os << "__global";
} else if (scope == "shared") {
os << "__local ";
}
}
...@@ -55,8 +134,6 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os
std::unordered_map<LoweredFunc, PackedFunc>
MakeOpenCL(Array<LoweredFunc> funcs) {
std::ostringstream os;
os << "typedef int int32_t;\n"
<< "typedef unsigned unt32_t;\n";
bool output_ssa = false;
for (LoweredFunc f : funcs) {
os << CodeGenOpenCL().Compile(f, output_ssa);
......
...@@ -30,6 +30,16 @@ class CodeGenOpenCL : public CodeGenC {
std::string tag, std::ostream& os) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const std::string& scope) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
void PrintVecLoad(const Variable* buffer,
Type t, Expr base,
std::ostream& os) final; // NOLINT(*)
void PrintVecStore(const Variable* buffer,
Type t, Expr base,
const std::string& value) final; // NOLINT(*)
// the address of load/store
void PrintVecAddr(const Variable* buffer, Type t,
Expr base, std::ostream& os); // NOLINT(*)
};
} // namespace codegen
......
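The PrintVecLoad/PrintVecStore overrides declared above emit OpenCL's vloadN/vstoreN builtins for vectorized buffer access. A hedged sketch of exercising that path from Python, assuming an OpenCL-enabled build (it mirrors the updated test_add later in this diff; the shapes, split factors, and names are illustrative):

import tvm

n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.Schedule(C.op)
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, 256), thread_tag="threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=256 * 4, outer=block_x)
_, x = s[C].split(x, outer=thread_x)
_, x = s[C].split(x, factor=4)
s[C].vectorize(x)  # the 4-wide inner loop should surface as vload4/vstore4 in the kernel
codes = []
tvm.init_opencl()
fadd = tvm.build(s, [A, B, C], "opencl", record_codes=codes, name="myadd")
print(codes[1])    # generated OpenCL source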
...@@ -74,6 +74,7 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
.DISPATCH_TO_MUTATE_STMT(Provide)
.DISPATCH_TO_MUTATE_STMT(Realize)
.DISPATCH_TO_MUTATE_STMT(Store)
.DISPATCH_TO_MUTATE_STMT(IfThenElse)
.DISPATCH_TO_MUTATE_STMT(For)
.DISPATCH_TO_MUTATE_STMT(Allocate)
.DISPATCH_TO_MUTATE_STMT(Free);
...@@ -195,6 +196,22 @@ Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
return s;
}
Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Expr condition = this->Mutate(op->condition);
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.DISPATCH_TO_MUTATE_EXPR(Call)
.DISPATCH_TO_MUTATE_EXPR(Let)
...@@ -363,21 +380,6 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
return Block::make(first, rest);
}
})
.set_dispatch<IfThenElse>([](const IfThenElse *op, const Stmt& s, IRMutator* m) {
Expr condition = m->Mutate(op->condition);
Stmt then_case = m->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
else_case = m->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
})
.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
Expr v = m->Mutate(op->value);
if (v.same_as(op->value)) {
......
...@@ -101,9 +101,14 @@ class IRUseDefAnalysis : public IRMutator {
}
void HandleDef(const Variable* v) {
CHECK(!def_count_.count(v))
<< "variable " << v->name_hint
<< " has already been defined, the Stmt is not SSA";
CHECK(!use_count_.count(v))
<< "variable " << v->name_hint
<< " has been used before definition!";
use_count_[v] = 0;
def_count_[v] = 1;
}
void HandleUse(const Expr& v) {
...@@ -127,6 +132,7 @@ class IRUseDefAnalysis : public IRMutator {
Array<IterVar> thread_axis_;
Array<Expr> thread_extent_;
std::unordered_map<const Variable*, int> use_count_;
std::unordered_map<const Variable*, int> def_count_;
};
class HostDeviceSplitter : public IRMutator {
......
/*!
* Copyright (c) 2017 by Contributors
* Loop unrolling.
* \file unroll_loop.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
...@@ -9,7 +9,7 @@
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
...@@ -33,7 +33,8 @@ class LoopUnroller : public IRMutator {
if (v2 != nullptr) {
value = static_cast<int>(v2->value);
}
bool allow_unroll = (op->for_type == ForType::Serial &&
value >= 0 && value <= max_auto_step_);
if (op->for_type == ForType::Unrolled) {
CHECK_GE(value, 0)
<< "Cannot unroll non-constant loop";
......
/*!
* Copyright (c) 2017 by Contributors
* Vectorize the loop
* \file vectorize_loop.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
inline Expr BroadcastTo(Expr e, int lanes) {
if (e.type().lanes() == lanes) return e;
CHECK_EQ(e.type().lanes(), 1)
<< "Cannot broadcast lane=" << e.type().lanes()
<< " to " << lanes;
return Broadcast::make(e, lanes);
}
// Rewrite vectorized allocation access
// s[i] = s[i * lanes + var]
class VecAllocAccess : public IRMutator {
public:
VecAllocAccess(const Variable* buf, Var var, int var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
// Load
Expr Mutate_(const Load* op, const Expr& e) final {
Expr expr = IRMutator::Mutate_(op, e);
op = expr.as<Load>();
if (op->buffer_var.get() == buf_) {
return Load::make(op->type, op->buffer_var,
op->index * var_lanes_ + var_);
} else {
return expr;
}
}
// Store
Stmt Mutate_(const Store* op, const Stmt& s) final {
Stmt stmt = IRMutator::Mutate_(op, s);
op = stmt.as<Store>();
if (op->buffer_var.get() == buf_) {
return Store::make(op->buffer_var,
op->value,
op->index * var_lanes_ + var_);
} else {
return stmt;
}
}
private:
// buffer var
const Variable* buf_;
// variable to be replaced
Var var_;
// the lanes.
int var_lanes_;
};
class Vectorizer : public IRMutator {
public:
Vectorizer(Var var, int var_lanes)
: var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp::make(0, 1, var_lanes);
}
// use Mutate from parent.
using IRMutator::Mutate;
// override mutate
Expr Mutate(Expr expr) final {
static const FMutateExpr& f = Vectorizer::vtable_expr();
return (f.can_dispatch(expr) ?
f(expr, expr, this) : IRMutator::Mutate(expr));
}
// Variable
Expr Mutate_(const Variable* v, const Expr& e) final {
if (v == var_.get()) {
return ramp_;
} else if (lets_.count(v)) {
return lets_[v];
} else {
return e;
}
}
// Call
Expr Mutate_(const Call* op, const Expr& e) final {
int lane = 0;
Array<Expr> new_args = MutateArray(op->args, &lane);
if (op->args.same_as(new_args)) {
return e;
} else {
return Call::make(
op->type.with_lanes(lane), op->name, new_args,
op->call_type, op->func, op->value_index);
}
}
// Load
Expr Mutate_(const Load* op, const Expr& e) final {
Expr index = this->Mutate(op->index);
if (index.same_as(op->index)) {
return e;
} else {
return Load::make(op->type.with_lanes(index.type().lanes()),
op->buffer_var, index);
}
}
// Let
Expr Mutate_(const Let* op, const Expr& e) final {
Expr value = this->Mutate(op->value);
CHECK(!lets_.count(op->var.get())) << "not SSA";
if (value.type().lanes() != op->value.type().lanes()) {
Var v(op->var->name_hint, value.type());
lets_[op->var.get()] = v;
return Let::make(v, value, Mutate(op->body));
} else {
Expr body = this->Mutate(op->body);
if (value.same_as(op->value) &&
body.same_as(op->body)) {
return e;
} else {
return Let::make(op->var, value, body);
}
}
}
// Provide
Stmt Mutate_(const Provide* op, const Stmt& s) final {
Expr new_value = this->Mutate(op->value);
int lane = new_value.type().lanes();
Array<Expr> new_args = MutateArray(op->args, &lane);
if (op->args.same_as(new_args) && op->value.same_as(new_value)) {
return s;
} else {
new_value = BroadcastTo(new_value, lane);
return Provide::make(op->func, op->value_index, new_value, new_args);
}
}
// Store
Stmt Mutate_(const Store* op, const Stmt& s) final {
Expr value = this->Mutate(op->value);
Expr index = this->Mutate(op->index);
if (value.same_as(op->value) && index.same_as(op->index)) {
return s;
} else {
int lanes = std::max(value.type().lanes(), index.type().lanes());
return Store::make(op->buffer_var,
BroadcastTo(value, lanes),
BroadcastTo(index, lanes));
}
}
// For
Stmt Mutate_(const For* op, const Stmt& s) final {
if (op->for_type == ForType::Vectorized) {
LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
}
CHECK(is_zero(op->min));
CHECK(!op->extent.type().is_vector());
Expr extent = Mutate(op->extent);
if (extent.type().is_vector()) {
LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
return Scalarize(s);
}
Stmt body = Mutate(op->body);
if (extent.same_as(op->extent) &&
body.same_as(op->body)) {
return s;
} else {
return For::make(
op->loop_var, op->min, extent,
op->for_type, op->device_api, body);
}
}
// IfThenElse
Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
CHECK(!op->condition.type().is_vector());
Expr condition = this->Mutate(op->condition);
if (condition.type().is_vector()) {
LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
return Scalarize(s);
}
Stmt then_case = this->Mutate(op->then_case);
Stmt else_case;
if (else_case.defined()) {
else_case = this->Mutate(op->else_case);
}
if (condition.same_as(op->condition) &&
then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return s;
} else {
return IfThenElse::make(condition, then_case, else_case);
}
}
// LetStmt
Stmt Mutate_(const LetStmt* op, const Stmt& s) final {
LOG(WARNING) << "Cannot vectorize with LetStmt, remove it with Simplify Before Vectorize";
return Scalarize(s);
}
// Allocate
Stmt Mutate_(const Allocate* op, const Stmt& s) final {
if (op->new_expr.defined()) {
LOG(WARNING) << "Cannot vectorize with new expr";
return Scalarize(s);
}
Expr condition = Mutate(op->condition);
if (condition.type().is_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc ";
return Scalarize(s);
}
Array<Expr> extents;
for (size_t i = 0; i < op->extents.size(); i++) {
Expr new_ext = Mutate(op->extents[i]);
if (new_ext.type().is_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc ";
return Scalarize(s);
}
extents.push_back(new_ext);
}
// place the vector lanes in least significant dimension.
extents.push_back(var_lanes_);
// rewrite access to buffer internally.
Stmt body = VecAllocAccess(
op->buffer_var.get(), var_, var_lanes_).Mutate(op->body);
body = Mutate(body);
return Allocate::make(
op->buffer_var, op->type,
extents, condition, body,
op->new_expr, op->free_function);
}
// scalarize the statement
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->type);
stmt = Substitute(stmt, {{var_, idx}});
return For::make(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
}
// The overloads for vectorize.
static FMutateExpr& vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
private:
// variable to be replaced
Var var_;
// the lanes.
int var_lanes_;
// ramp representing the var.
Expr ramp_;
// The lets
std::unordered_map<const Variable*, Expr> lets_;
// mutate array, with given lane requirement
// when finished, p_lane updates the lane requirement.
Array<Expr> MutateArray(Array<Expr> arr, int* p_lanes) {
if (arr.size() == 0) return arr;
int& lanes = *p_lanes;
bool changed = false;
std::vector<Expr> new_arr(arr.size());
for (size_t i = 0; i < arr.size(); i++) {
Expr old_elem = arr[i];
Expr new_elem = this->Mutate(old_elem);
if (!new_elem.same_as(old_elem)) changed = true;
new_arr[i] = new_elem;
lanes = std::max(lanes, new_elem.type().lanes());
}
for (size_t i = 0; i < arr.size(); ++i) {
if (new_arr[i].type().lanes() != lanes) {
new_arr[i] = BroadcastTo(new_arr[i], lanes);
changed = true;
}
}
if (!changed) return arr;
return Array<Expr>(new_arr);
}
};
// binary vectorize
template<typename T>
inline Expr BinaryVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
} else {
int lanes = std::max(a.type().lanes(), b.type().lanes());
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
template<typename T>
inline Expr AddSubVec(const T* op, const Expr& e, IRMutator* m) {
Expr a = m->Mutate(op->a);
Expr b = m->Mutate(op->b);
if (a.same_as(op->a) &&
b.same_as(op->b)) {
return e;
} else {
int lanes = std::max(a.type().lanes(), b.type().lanes());
if (lanes != 1) {
const Ramp* b_ramp = b.as<Ramp>();
const Ramp* a_ramp = a.as<Ramp>();
if (a.type().lanes() == 1 && b_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a, b_ramp->base), b_ramp->stride, b_ramp->lanes);
}
if (b.type().lanes() == 1 && a_ramp) {
return Ramp::make(
arith::ComputeExpr<T>(a_ramp->base, b), a_ramp->stride, a_ramp->lanes);
}
}
return T::make(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
}
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Add>(AddSubVec<Add>)
.set_dispatch<Sub>(AddSubVec<Sub>)
.set_dispatch<Mul>(BinaryVec<Mul>)
.set_dispatch<Div>(BinaryVec<Div>)
.set_dispatch<Mod>(BinaryVec<Mod>)
.set_dispatch<Min>(BinaryVec<Min>)
.set_dispatch<Max>(BinaryVec<Max>)
.set_dispatch<EQ>(BinaryVec<EQ>)
.set_dispatch<NE>(BinaryVec<NE>)
.set_dispatch<LT>(BinaryVec<LT>)
.set_dispatch<LE>(BinaryVec<LE>)
.set_dispatch<GT>(BinaryVec<GT>)
.set_dispatch<GE>(BinaryVec<GE>)
.set_dispatch<And>(BinaryVec<And>)
.set_dispatch<Or>(BinaryVec<Or>);
TVM_STATIC_IR_FUNCTOR(Vectorizer, vtable_expr)
.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
Expr cond = m->Mutate(op->condition);
Expr t = m->Mutate(op->true_value);
Expr f = m->Mutate(op->false_value);
if (cond.same_as(op->condition) &&
t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return e;
} else {
int lanes = std::max(std::max(
cond.type().lanes(),
t.type().lanes()), f.type().lanes());
return Select::make(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
}
})
.set_dispatch<Cast>([](const Cast *op, const Expr& e, IRMutator* m) {
Expr value = m->Mutate(op->value);
if (value.same_as(op->value)) {
return e;
} else {
return Cast::make(op->type.with_lanes(value.type().lanes()), value);
}
});
class LoopVectorizer : public IRMutator {
public:
Stmt Mutate_(const For* op, const Stmt& s) final {
if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min));
CHECK(is_positive_const(op->extent));
int lanes = 0;
bool succ = arith::GetConstInt(op->extent, &lanes);
if (!succ || lanes < 1) {
LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
}
Var var(op->loop_var.node_);
return Vectorizer(var, lanes).Mutate(op->body);
} else {
return IRMutator::Mutate_(op, s);
}
}
};
Stmt VectorizeLoop(Stmt stmt) {
return LoopVectorizer().Mutate(stmt);
}
} // namespace ir
} // namespace tvm
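For reference, the new pass can also be driven directly on hand-built IR; a minimal sketch that condenses the test_pass_vectorize test added below (the buffer, dtype, and loop body are illustrative):

import tvm

n = tvm.Var('n')
Ab = tvm.Buffer((n, ), 'int64')
i, j = tvm.Var('i'), tvm.Var('j')
VECTORIZE = 2  # ForType::Vectorized
stmt = tvm.make.For(
    i, n, 2, 0, 0,
    tvm.make.For(j, 0, 4, VECTORIZE, 0,
                 tvm.make.Store(Ab.data,
                                tvm.make.Load('int64', Ab.data, i) + 1,
                                j + 1)))
stmt = tvm.ir_pass.VectorizeLoop(stmt)  # inner For becomes Ramp/Broadcast vector ops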
...@@ -57,6 +57,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
<< ")";
});
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarAttrNode>([](const IterVarAttrNode *op, IRPrinter *p) {
switch (op->iter_type) {
case kUnrolled: p->stream << "unroll"; break;
case kVectorized: p->stream << "vectorize"; break;
}
});
Stage::Stage(Operation op) {
auto n = std::make_shared<StageNode>();
n->op = op;
...@@ -246,7 +254,38 @@ void Schedule::normalize() {
}
}
IterVarAttr::IterVarAttr(IterVarType t) {
std::shared_ptr<IterVarAttrNode> n = std::make_shared<IterVarAttrNode>();
n->iter_type = t;
node_ = n;
}
inline void SetAttr(StageNode* self, IterVar var, IterVarAttr attr) {
ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
FindLeafVar(all_vars, leaf_vars, var);
auto it = self->iter_var_attrs.find(var);
if (it != self->iter_var_attrs.end()) {
CHECK_EQ((*it).second->iter_type, attr->iter_type)
<< "IterVar's is already set to "
<< (*it).second << " instead of " << attr;
} else {
self->iter_var_attrs.Set(var, attr);
}
}
Stage& Stage::vectorize(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kVectorized));
return *this;
}
Stage& Stage::unroll(IterVar var) { // NOLINT(*)
SetAttr(operator->(), var, IterVarAttr(kUnrolled));
return *this;
}
TVM_REGISTER_NODE_TYPE(StageNode);
TVM_REGISTER_NODE_TYPE(IterVarAttrNode);
TVM_REGISTER_NODE_TYPE(SplitNode);
TVM_REGISTER_NODE_TYPE(FuseNode);
TVM_REGISTER_NODE_TYPE(RebaseNode);
......
...@@ -177,6 +177,13 @@ MakeLoopNest(const Stage& sch,
}
// Mark the iter var in the IR, to remember the point
if (iv->thread_tag.length() == 0) {
ForType for_type = ForType::Serial;
if (sch->iter_var_attrs.count(iv)) {
switch (sch->iter_var_attrs[iv]->iter_type) {
case kUnrolled: for_type = ForType::Unrolled; break;
case kVectorized: for_type = ForType::Vectorized; break;
}
}
if (is_one(dom->extent)) {
nest[i + 1].emplace_back(
LetStmt::make(var, dom->min, no_op));
...@@ -184,13 +191,13 @@ MakeLoopNest(const Stage& sch,
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(
For::make(var, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
value_map[iv] = var;
} else {
Var idx(iv->var->name_hint + ".idx", iv->var.type());
nest[i + 1].emplace_back(
For::make(idx, 0, dom->extent,
for_type, DeviceAPI::None, no_op));
Expr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(
......
...@@ -3,7 +3,7 @@ import numpy as np
def test_add():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
...@@ -13,26 +13,28 @@ def test_add():
num_thread = 256
block_x = tvm.IterVar(thread_tag="blockIdx.x")
thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
_, x = s[C].split(C.op.axis[0], factor=num_thread*4, outer=block_x)
_, x = s[C].split(x, outer=thread_x)
_, x = s[C].split(x, factor=4)
s[C].vectorize(x)
# one line to build the function.
def check_device(target):
codes = []
fadd = tvm.build(s, [A, B, C],
target, record_codes=codes,
name="myadd")
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
return
for c in codes[1:]:
print(c)
# launch the kernel.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
...@@ -40,6 +42,10 @@ def test_add():
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
tvm.init_opencl()
check_device("cuda")
check_device("opencl")
if __name__ == "__main__":
test_add()
...@@ -76,6 +76,21 @@ def test_fuse():
assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations)
assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi)
def test_vectorize():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])
s = tvm.Schedule(T.op)
xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5)
s[T].vectorize(yi)
s[T].unroll(xi)
UNROLL = 1
VECTORIZE = 2
assert s[T].iter_var_attrs[xi].iter_type == UNROLL
assert s[T].iter_var_attrs[yi].iter_type == VECTORIZE
if __name__ == "__main__":
test_schedule_create()
...@@ -83,3 +98,4 @@ if __name__ == "__main__":
test_tile()
test_split()
test_fuse()
test_vectorize()
...@@ -9,11 +9,13 @@ def test_unroll_loop():
# for i in 0 to n-1:
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, 8, 3, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
assert isinstance(stmt, tvm.stmt.For)
stmt = tvm.ir_pass.UnrollLoop(stmt, 4)
assert not isinstance(stmt, tvm.stmt.For)
print(stmt)
if __name__ == "__main__":
......
import tvm
def test_vectorize_loop():
dtype = 'int64'
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
j = tvm.Var('j')
VECTORIZE = 2
# for i in 0 to n-1:
stmt = tvm.make.For(
i, n, 2, 0, 0,
tvm.make.For(j, 0, 4, VECTORIZE, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
j + 1)))
assert isinstance(stmt.body, tvm.stmt.For)
stmt = tvm.ir_pass.VectorizeLoop(stmt)
assert isinstance(stmt, tvm.stmt.For)
assert not isinstance(stmt.body, tvm.stmt.For)
print(stmt)
if __name__ == "__main__":
test_vectorize_loop()