Commit d8be197d by Lianmin Zheng Committed by Tianqi Chen

[CODEGEN] fix & improments in codegen (#745)

* [CODEGEN] update codegen for vector operation

* update comment, fix for metal

* fix some bugs in codegen

* use 'restrict' in every argument

* fix

* fix
parent 3d5032ae
......@@ -228,7 +228,7 @@ void CodeGenC::RegisterHandleType(const Variable* buf_var, Type t) {
void CodeGenC::PrintVecElemLoad(const std::string& vec,
Type t, int i,
std::ostream& os) { // NOLINT(*)
os << vec << ".s" << std::hex << i;
os << vec << ".s" << std::hex << i << std::dec;
}
void CodeGenC::PrintVecElemStore(const std::string& vec,
......@@ -236,7 +236,7 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
const std::string& value) {
this->PrintIndent();
stream << vec << ".s" << std::hex << i
<< " = " << value << ";\n";
<< " = " << value << ";\n" << std::dec;
}
std::string CodeGenC::GetVecLoad(
......@@ -583,6 +583,13 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
std::ostringstream value_temp;
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
value_temp << "((";
if (op->buffer_var.get()->type.is_handle()) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, value_temp);
value_temp << ' ';
}
}
PrintType(elem_type, value_temp);
value_temp << "*)" << vid << ')';
} else {
......@@ -627,6 +634,13 @@ void CodeGenC::VisitStmt_(const Store* op) {
Type elem_type = t.element_of();
if (!HandleTypeMatch(op->buffer_var.get(), elem_type)) {
stream << "((";
if (op->buffer_var.get()->type.is_handle()) {
auto it = alloc_storage_scope_.find(op->buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
stream << ' ';
}
}
PrintType(elem_type, stream);
stream << "*)" << vid << ')';
} else {
......
......@@ -177,14 +177,14 @@ void CodeGenOpenCL::PrintStorageScope(
void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLINT(*)
std::string v = PrintExpr(op->value);
os << '(';
os << "((";
PrintType(op->type, os);
os << ")(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
os << ')';
os << "))";
}
} // namespace codegen
} // namespace tvm
......@@ -154,8 +154,8 @@ inline Type APIType(Type t) {
inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
int align = runtime::kTempAllocaAlignment;
if (const_size > 0) {
const_size = const_size * type.bits() * type.lanes() / 8;
while (align > const_size) {
int64_t const_s = static_cast<int64_t>(const_size) * type.bits() * type.lanes() / 8;
while (align > const_s) {
align = align / 2;
}
}
......
......@@ -191,6 +191,9 @@ class HostDeviceSplitter : public IRMutator {
auto it = handle_data_type_.find(v.get());
if (it != handle_data_type_.end()) {
n->handle_data_type.Set(v, it->second);
} else {
// int32 as a placeholder
n->handle_data_type.Set(v, make_const(UInt(32), 0));
}
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment