Commit 0d611134 by Andrew Tulloch Committed by Tianqi Chen

Metal reinterpret fix (#3706)

parent c8654e2a
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -246,6 +246,19 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI
os << ')';
}
void CodeGenMetal::VisitExpr_(const Call* op, std::ostream& os) { // NOLINT(*)
if (op->is_intrinsic(Call::reinterpret)) {
// generate as_type<TYPE>(ARG)
os << "(as_type<";
this->PrintType(op->type, os);
os << ">(";
this->PrintExpr(op->args[0], os);
os << "))";
} else {
CodeGenC::VisitExpr_(op, os);
}
}
runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
......
......@@ -53,6 +53,9 @@ class CodeGenMetal final : public CodeGenC {
// overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
// overload visitor
void VisitExpr_(const Call* op, std::ostream& os) final; // NOLINT(*)
private:
int thread_index_bits_{32};
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment