Commit 0d611134 by Andrew Tulloch Committed by Tianqi Chen

Metal reinterpret fix (#3706)

parent c8654e2a
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -246,6 +246,19 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI ...@@ -246,6 +246,19 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI
os << ')'; os << ')';
} }
void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(Call::reinterpret)) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
this->PrintType(op->type, os);
os << ">(";
this->PrintExpr(op->args[0], os);
os << "))";
} else {
CodeGenC::VisitExpr_(op, os);
}
}
runtime::Module BuildMetal(Array<LoweredFunc> funcs) { runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry; using tvm::runtime::Registry;
bool output_ssa = false; bool output_ssa = false;
......
...@@ -53,6 +53,9 @@ class CodeGenMetal final : public CodeGenC { ...@@ -53,6 +53,9 @@ class CodeGenMetal final : public CodeGenC {
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
private: private:
int thread_index_bits_{32}; int thread_index_bits_{32};
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment