Commit d8be197d by Lianmin Zheng Committed by Tianqi Chen

[CODEGEN] fix & improments in codegen (#745)

* [CODEGEN] update codegen for vector operation

* update comment, fix for metal

* fix some bugs in codegen

* use 'restrict' in every argument

* fix

* fix
parent 3d5032ae
...@@ -228,7 +228,7 @@ void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) { ...@@ -228,7 +228,7 @@ void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
void CodeGenC::PrintVecElemLoad(const std::string& vec, void CodeGenC::PrintVecElemLoad(const std::string& vec,
Type t, int i, Type t, int i,
std::ostream& os) { // NOLINT(*) std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i; os << vec << ".s" << std::hex << i << std::dec;
} }
void CodeGenC::PrintVecElemStore(const std::string& vec, void CodeGenC::PrintVecElemStore(const std::string& vec,
...@@ -236,7 +236,7 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, ...@@ -236,7 +236,7 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
const std::string& value) { const std::string& value) {
this->PrintIndent(); this->PrintIndent();
stream << vec << ".s" << std::hex << i stream << vec << ".s" << std::hex << i
<< " = " << value << ";\n"; << " = " << value << ";\n" << std::dec;
} }
std::string CodeGenC::GetVecLoad( std::string CodeGenC::GetVecLoad(
...@@ -583,6 +583,13 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -583,6 +583,13 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
std::ostringstream value_temp; std::ostringstream value_temp;
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
value_temp << "(("; value_temp << "((";
if (op->buffer_var.get()->type.is_handle()) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
value_temp << ' ';
}
}
PrintType(elem_type, value_temp); PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')'; value_temp << "*)" << vid << ')';
} else { } else {
...@@ -627,6 +634,13 @@ void CodeGenC::VisitStmt_(const Store* op) { ...@@ -627,6 +634,13 @@ void CodeGenC::VisitStmt_(const Store* op) {
Type elem_type = t.element_of(); Type elem_type = t.element_of();
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) { if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
stream << "(("; stream << "((";
if (op->buffer_var.get()->type.is_handle()) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
}
}
PrintType(elem_type, stream); PrintType(elem_type, stream);
stream << "*)" << vid << ')'; stream << "*)" << vid << ')';
} else { } else {
......
...@@ -177,14 +177,14 @@ void CodeGenOpenCL::PrintStorageScope( ...@@ -177,14 +177,14 @@ void CodeGenOpenCL::PrintStorageScope(
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*) void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value); std::string v = PrintExpr(op->value);
os << '('; os << "((";
PrintType(op->type, os); PrintType(op->type, os);
os << ")("; os << ")(";
for (int i = 0; i < op->lanes; ++i) { for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", "; if (i != 0) os << ", ";
os << v; os << v;
} }
os << ')'; os << "))";
} }
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -154,8 +154,8 @@ inline Type APIType(Type t) { ...@@ -154,8 +154,8 @@ inline Type APIType(Type t) {
inline int GetTempAllocaAlignment(Type type, int32_t const_size) { inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
int align = runtime::kTempAllocaAlignment; int align = runtime::kTempAllocaAlignment;
if (const_size > 0) { if (const_size > 0) {
const_size = const_size * type.bits() * type.lanes() / 8; int64_t const_s = static_cast<int64_t>(const_size) * type.bits() * type.lanes() / 8;
while (align > const_size) { while (align > const_s) {
align = align / 2; align = align / 2;
} }
} }
......
...@@ -191,6 +191,9 @@ class HostDeviceSplitter : public IRMutator { ...@@ -191,6 +191,9 @@ class HostDeviceSplitter : public IRMutator {
auto it = handle_data_type_.find(v.get()); auto it = handle_data_type_.find(v.get());
if (it != handle_data_type_.end()) { if (it != handle_data_type_.end()) {
n->handle_data_type.Set(v, it->second); n->handle_data_type.Set(v, it->second);
} else {
// int32 as a placeholder
n->handle_data_type.Set(v, make_const(UInt(32), 0));
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment