Commit 0220abba by Tianqi Chen Committed by GitHub

[METAL] use 32bit indexing for metal until we have a bound adapted pass (#462)

* [METAL] use 32bit indexing for metal until we have a bound adapted pass

* fix lint
parent 324fe165
...@@ -101,21 +101,21 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { ...@@ -101,21 +101,21 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
if (work_dim != 0) { if (work_dim != 0) {
// use ushort by default for now // use ushort by default for now
stream << " "; stream << " ";
PrintType(UInt(16, work_dim), stream); PrintType(UInt(thread_index_bits_, work_dim), stream);
stream << " blockIdx [[threadgroup_position_in_grid]],\n"; stream << " blockIdx [[threadgroup_position_in_grid]],\n";
stream << " "; stream << " ";
PrintType(UInt(16, work_dim), stream); PrintType(UInt(thread_index_bits_, work_dim), stream);
stream << " threadIdx [[thread_position_in_threadgroup]]\n"; stream << " threadIdx [[thread_position_in_threadgroup]]\n";
} }
// bind thread axis // bind thread axis
for (IterVar iv : f->thread_axis) { for (IterVar iv : f->thread_axis) {
CHECK(!var_idmap_.count(iv->var.get())); CHECK(!var_idmap_.count(iv->var.get()));
std::string vname = iv->thread_tag;
if (work_dim <= 1) { if (work_dim <= 1) {
var_idmap_[iv->var.get()] = vname = vname.substr(0, iv->thread_tag.length() - 2);
iv->thread_tag.substr(0, iv->thread_tag.length() - 2);
} else {
var_idmap_[iv->var.get()] = iv->thread_tag;
} }
var_idmap_[iv->var.get()] =
CastFromTo(vname, UInt(thread_index_bits_), iv->var.type());
} }
// the function scope. // the function scope.
stream << ") {\n"; stream << ") {\n";
...@@ -129,7 +129,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) { ...@@ -129,7 +129,7 @@ void CodeGenMetal::AddFunction(LoweredFunc f) {
void CodeGenMetal::BindThreadIndex(const IterVar& iv) { void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
CHECK(!var_idmap_.count(iv->var.get())); CHECK(!var_idmap_.count(iv->var.get()));
var_idmap_[iv->var.get()] = var_idmap_[iv->var.get()] =
CastFromTo(iv->thread_tag, UInt(16), iv->var.type()); CastFromTo(iv->thread_tag, UInt(thread_index_bits_), iv->var.type());
} }
void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*) void CodeGenMetal::PrintType(Type t, std::ostream& os) const { // NOLINT(*)
......
...@@ -27,6 +27,9 @@ class CodeGenMetal final : public CodeGenC { ...@@ -27,6 +27,9 @@ class CodeGenMetal final : public CodeGenC {
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
private:
int thread_index_bits_{32};
}; };
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment