Commit 5d37be62 by Lianmin Zheng Committed by Tianqi Chen

[CODEGEN] update codegen for vector operation (#711)

* [CODEGEN] update codegen for vector operation

* update comment, fix for metal
parent 89b8456e
...@@ -567,6 +567,10 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -567,6 +567,10 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base); std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
os << ref; os << ref;
} else { } else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();
// load seperately. // load seperately.
std::string svalue = GetUniqueName("_"); std::string svalue = GetUniqueName("_");
this->PrintIndent(); this->PrintIndent();
...@@ -590,6 +594,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -590,6 +594,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
PrintVecElemStore(svalue, op->type, i, value_temp.str()); PrintVecElemStore(svalue, op->type, i, value_temp.str());
} }
os << svalue; os << svalue;
EndScope(vec_scope);
} }
} }
} }
...@@ -609,6 +614,10 @@ void CodeGenC::VisitStmt_(const Store* op) { ...@@ -609,6 +614,10 @@ void CodeGenC::VisitStmt_(const Store* op) {
std::string value = this->PrintExpr(op->value); std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value); this->PrintVecStore(op->buffer_var.get(), t, base, value);
} else { } else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();
// store elements seperately // store elements seperately
std::string index = SSAGetID(PrintExpr(op->index), op->index.type()); std::string index = SSAGetID(PrintExpr(op->index), op->index.type());
std::string value = SSAGetID(PrintExpr(op->value), op->value.type()); std::string value = SSAGetID(PrintExpr(op->value), op->value.type());
...@@ -629,6 +638,7 @@ void CodeGenC::VisitStmt_(const Store* op) { ...@@ -629,6 +638,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
PrintVecElemLoad(value, op->value.type(), i, stream); PrintVecElemLoad(value, op->value.type(), i, stream);
stream << ";\n"; stream << ";\n";
} }
EndScope(vec_scope);
} }
} }
} }
...@@ -642,7 +652,13 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*) ...@@ -642,7 +652,13 @@ void CodeGenC::VisitExpr_(const Let* op, std::ostream& os) { // NOLINT(*)
} }
void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Ramp* op, std::ostream& os) { // NOLINT(*)
LOG(FATAL) << "Ramp: not supported "; os << "((int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
if (i != op->lanes - 1)
os << ", ";
}
os << "))";
} }
void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
......
...@@ -120,6 +120,10 @@ void CodeGenCUDA::PrintVecBinaryOp( ...@@ -120,6 +120,10 @@ void CodeGenCUDA::PrintVecBinaryOp(
int lanes = t.lanes(); int lanes = t.lanes();
{ {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();
// default: unpack into individual ops. // default: unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.type()); std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.type());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.type()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.type());
...@@ -148,6 +152,7 @@ void CodeGenCUDA::PrintVecBinaryOp( ...@@ -148,6 +152,7 @@ void CodeGenCUDA::PrintVecBinaryOp(
PrintVecElemStore(sret, t, i, value_temp.str()); PrintVecElemStore(sret, t, i, value_temp.str());
} }
os << sret; os << sret;
EndScope(vec_scope);
} }
} }
...@@ -232,6 +237,16 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) { ...@@ -232,6 +237,16 @@ void CodeGenCUDA::VisitStmt_(const Evaluate *op) {
} }
} }
void CodeGenCUDA::VisitExpr_(const Ramp* op, std::ostream& os) {
os << "((make_int" << op->lanes << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")" << "+(" << PrintExpr(op->stride) << "*" << i <<")";
if (i != op->lanes - 1)
os << ", ";
}
os << "))";
}
void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenCUDA::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value); std::string v = PrintExpr(op->value);
os << "make_"; os << "make_";
......
...@@ -33,6 +33,7 @@ class CodeGenCUDA final : public CodeGenC { ...@@ -33,6 +33,7 @@ class CodeGenCUDA final : public CodeGenC {
const std::string& vec, Type t, int i, const std::string& value) final; const std::string& vec, Type t, int i, const std::string& value) final;
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Ramp* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
void VisitStmt_(const Evaluate *op) final; void VisitStmt_(const Evaluate *op) final;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment