Unverified Commit 952def53 by Wei Pan Committed by GitHub

[CodeGen] Cleanup generated code (#5424)

- remove unnecessary white spaces from storage kind
- do not start a new scope for vectorization as temporary
  variables are alll uniquely generated.

The above two changes make vectorized code much cleaner.

Signed-off-by: Wei Pan <weip@nvidia.com>
parent b637840b
...@@ -94,7 +94,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) { ...@@ -94,7 +94,6 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
auto it = alloc_storage_scope_.find(v.get()); auto it = alloc_storage_scope_.find(v.get());
if (it != alloc_storage_scope_.end()) { if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream); PrintStorageScope(it->second, stream);
stream << ' ';
} }
PrintType(GetType(v), stream); PrintType(GetType(v), stream);
...@@ -179,7 +178,6 @@ std::string CodeGenC::GetBufferRef( ...@@ -179,7 +178,6 @@ std::string CodeGenC::GetBufferRef(
if (!scope.empty() && IsScopePartOfType()) { if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os); PrintStorageScope(scope, os);
} }
os << ' ';
PrintType(t, os); PrintType(t, os);
os << "*)" << vid << ')'; os << "*)" << vid << ')';
} else { } else {
...@@ -213,7 +211,6 @@ std::string CodeGenC::GetBufferRef( ...@@ -213,7 +211,6 @@ std::string CodeGenC::GetBufferRef(
if (!scope.empty() && IsScopePartOfType()) { if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os); PrintStorageScope(scope, os);
} }
os << ' ';
PrintType(t, os); PrintType(t, os);
os << "*)("; os << "*)(";
if (!HandleTypeMatch(buffer, t.element_of())) { if (!HandleTypeMatch(buffer, t.element_of())) {
...@@ -221,7 +218,6 @@ std::string CodeGenC::GetBufferRef( ...@@ -221,7 +218,6 @@ std::string CodeGenC::GetBufferRef(
if (!scope.empty() && IsScopePartOfType()) { if (!scope.empty() && IsScopePartOfType()) {
PrintStorageScope(scope, os); PrintStorageScope(scope, os);
} }
os << ' ';
PrintType(t.element_of(), os); PrintType(t.element_of(), os);
os << "*)"; os << "*)";
} }
...@@ -681,7 +677,6 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) ...@@ -681,7 +677,6 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
auto it = alloc_storage_scope_.find(op->buffer_var.get()); auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) { if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp); PrintStorageScope(it->second, value_temp);
value_temp << ' ';
} }
} }
PrintType(elem_type, value_temp); PrintType(elem_type, value_temp);
...@@ -731,7 +726,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { ...@@ -731,7 +726,6 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
auto it = alloc_storage_scope_.find(op->buffer_var.get()); auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) { if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream); PrintStorageScope(it->second, stream);
stream << ' ';
} }
} }
PrintType(elem_type, stream); PrintType(elem_type, stream);
...@@ -823,10 +817,8 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { ...@@ -823,10 +817,8 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) {
const VarNode* buffer = op->buffer_var.as<VarNode>(); const VarNode* buffer = op->buffer_var.as<VarNode>();
std::string scope = alloc_storage_scope_.at(buffer); std::string scope = alloc_storage_scope_.at(buffer);
PrintStorageScope(scope, stream); PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream); PrintType(op->dtype, stream);
stream << ' '<< vid << '[' stream << ' ' << vid << '[' << constant_size << "];\n";
<< constant_size << "];\n";
RegisterHandleType(op->buffer_var.get(), op->dtype); RegisterHandleType(op->buffer_var.get(), op->dtype);
this->PrintStmt(op->body); this->PrintStmt(op->body);
......
...@@ -257,29 +257,6 @@ class CodeGenC : ...@@ -257,29 +257,6 @@ class CodeGenC :
/*! \brief the data type of allocated buffers */ /*! \brief the data type of allocated buffers */
std::unordered_map<const VarNode*, DataType> handle_data_type_; std::unordered_map<const VarNode*, DataType> handle_data_type_;
/*!
* \brief A RAII utility class for emitting code in a scoped region.
*/
class EnterScopeRAII {
// The codegen context.
CodeGenC* cg;
// The new scope level.
int scope;
public:
explicit EnterScopeRAII(CodeGenC* cg) : cg(cg) {
cg->PrintIndent();
cg->stream << "{\n";
scope = cg->BeginScope();
}
~EnterScopeRAII() {
cg->EndScope(scope);
cg->PrintIndent();
cg->stream << "}\n";
}
};
private: private:
/*! \brief whether to print in SSA form */ /*! \brief whether to print in SSA form */
bool print_ssa_form_{false}; bool print_ssa_form_{false};
......
...@@ -242,8 +242,6 @@ void CodeGenCUDA::PrintVecBinaryOp( ...@@ -242,8 +242,6 @@ void CodeGenCUDA::PrintVecBinaryOp(
this->PrintType(t, stream); this->PrintType(t, stream);
stream << ' ' << sret << ";\n"; stream << ' ' << sret << ";\n";
{ {
EnterScopeRAII scope(this);
// Unpack into individual ops. // Unpack into individual ops.
std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype()); std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype()); std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());
...@@ -350,7 +348,7 @@ void CodeGenCUDA::PrintStorageScope( ...@@ -350,7 +348,7 @@ void CodeGenCUDA::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*) const std::string& scope, std::ostream& os) { // NOLINT(*)
CHECK_NE(scope, "global"); CHECK_NE(scope, "global");
if (scope == "shared") { if (scope == "shared") {
os << "__shared__"; os << "__shared__ ";
} }
} }
...@@ -370,7 +368,6 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { ...@@ -370,7 +368,6 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
this->PrintType(target_ty, stream); this->PrintType(target_ty, stream);
stream << ' ' << sret << ";\n"; stream << ' ' << sret << ";\n";
{ {
EnterScopeRAII scope(this);
std::string src = SSAGetID(PrintExpr(op->value), from_ty); std::string src = SSAGetID(PrintExpr(op->value), from_ty);
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) { for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val; std::ostringstream val;
...@@ -470,8 +467,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) { ...@@ -470,8 +467,6 @@ void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
this->PrintType(op->dtype, stream); this->PrintType(op->dtype, stream);
stream << ' ' << sret << ";\n"; stream << ' ' << sret << ";\n";
{ {
EnterScopeRAII scope(this);
// Load arguments. // Load arguments.
std::vector<std::string> sargs; std::vector<std::string> sargs;
for (size_t i = 0; i < op->args.size(); ++i) { for (size_t i = 0; i < op->args.size(); ++i) {
...@@ -541,7 +536,6 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { ...@@ -541,7 +536,6 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
PrintWmmaScope(scope, op->dtype, buffer, stream); PrintWmmaScope(scope, op->dtype, buffer, stream);
} else { } else {
PrintStorageScope(scope, stream); PrintStorageScope(scope, stream);
stream << ' ';
PrintType(op->dtype, stream); PrintType(op->dtype, stream);
} }
if ((op->dtype == DataType::Int(4) || if ((op->dtype == DataType::Int(4) ||
...@@ -657,8 +651,6 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) { ...@@ -657,8 +651,6 @@ void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream &os) {
this->PrintType(op->dtype, stream); this->PrintType(op->dtype, stream);
stream << ' ' << r_var << ";\n"; stream << ' ' << r_var << ";\n";
{ {
EnterScopeRAII scope(this);
std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype); std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype); std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype); std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);
......
...@@ -74,7 +74,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ...@@ -74,7 +74,6 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
if (it != alloc_storage_scope_.end()) { if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream); PrintStorageScope(it->second, stream);
} }
stream << ' ';
PrintType(GetType(v), stream); PrintType(GetType(v), stream);
// Register handle data type // Register handle data type
// TODO(tvm-team): consider simply keep type info in the // TODO(tvm-team): consider simply keep type info in the
...@@ -236,11 +235,11 @@ void CodeGenMetal::PrintVecElemStore(const std::string& vec, ...@@ -236,11 +235,11 @@ void CodeGenMetal::PrintVecElemStore(const std::string& vec,
void CodeGenMetal::PrintStorageScope( void CodeGenMetal::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*) const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") { if (scope == "global") {
os << "device"; os << "device ";
} else if (scope == "shared") { } else if (scope == "shared") {
os << "threadgroup"; os << "threadgroup ";
} else { } else {
os << "thread"; os << "thread ";
} }
} }
......
...@@ -150,7 +150,6 @@ void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t, ...@@ -150,7 +150,6 @@ void CodeGenOpenCL::PrintVecAddr(const VarNode* buffer, DataType t,
if (it != alloc_storage_scope_.end()) { if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, os); PrintStorageScope(it->second, os);
} }
os << ' ';
PrintType(t.element_of(), os); PrintType(t.element_of(), os);
os << "*)"; os << "*)";
} }
...@@ -191,9 +190,9 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) { ...@@ -191,9 +190,9 @@ void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
void CodeGenOpenCL::PrintStorageScope( void CodeGenOpenCL::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*) const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") { if (scope == "global") {
os << "__global"; os << "__global ";
} else if (scope == "shared") { } else if (scope == "shared") {
os << "__local"; os << "__local ";
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment