Commit 7d2654c2 by Tianqi Chen Committed by GitHub

[CODEGEN] Fix vector element access in metal (#872)

parent b5bd923a
...@@ -59,6 +59,7 @@ def test_rpc_module(): ...@@ -59,6 +59,7 @@ def test_rpc_module():
# Start RPC test server that contains the compiled library. # Start RPC test server that contains the compiled library.
server = xcode.popen_test_rpc(proxy_host, proxy_port, key, server = xcode.popen_test_rpc(proxy_host, proxy_port, key,
destination=destination, destination=destination,
options=['-quiet'],
libs=[path_dso1, path_dso2]) libs=[path_dso1, path_dso2])
# connect to the proxy # connect to the proxy
......
...@@ -201,5 +201,11 @@ def popen_test_rpc(host, ...@@ -201,5 +201,11 @@ def popen_test_rpc(host,
if options: if options:
cmd += options cmd += options
cmd += ["test"] cmd += ["test"]
proc = subprocess.Popen(cmd) if "-quiet" in options:
with open(os.devnull, 'w') as devnull:
proc = subprocess.Popen(cmd,
stderr=subprocess.STDOUT,
stdout=devnull)
else:
proc = subprocess.Popen(cmd)
return proc return proc
...@@ -117,7 +117,7 @@ class Target(object): ...@@ -117,7 +117,7 @@ class Target(object):
self.keys += ("rocm", "gpu") self.keys += ("rocm", "gpu")
self.max_num_threads = 256 self.max_num_threads = 256
elif target_name in ("metal", "vulkan"): elif target_name in ("metal", "vulkan"):
self.keys += ("gpu",) self.keys += (target_name, "gpu",)
self.max_num_threads = 256 self.max_num_threads = 256
elif target_name in ("opengl",): elif target_name in ("opengl",):
self.keys += ("opengl",) self.keys += ("opengl",)
......
...@@ -186,6 +186,20 @@ void CodeGenMetal::PrintStorageSync(const Call* op) { ...@@ -186,6 +186,20 @@ void CodeGenMetal::PrintStorageSync(const Call* op) {
} }
} }
// Emit an expression that reads a single element of a vector value.
// Writes `vec[i]` (subscript form) to `os`; no indentation or statement
// terminator is added, so the result can be embedded in a larger expression.
//   vec - source text naming the vector being read
//   t   - element/vector type; not consulted by this implementation
//   i   - index of the element to read
//   os  - stream receiving the generated expression
void CodeGenMetal::PrintVecElemLoad(const std::string& vec,
                                    Type t, int i,
                                    std::ostream& os) {  // NOLINT(*)
  os << vec;
  os << '[' << i << ']';
}
void CodeGenMetal::PrintVecElemStore(const std::string& vec,
Type t, int i,
const std::string& value) {
this->PrintIndent();
stream << vec << "[" << i << "]"
<< " = " << value << ";\n";
}
void CodeGenMetal::PrintStorageScope( void CodeGenMetal::PrintStorageScope(
const std::string& scope, std::ostream& os) { // NOLINT(*) const std::string& scope, std::ostream& os) { // NOLINT(*)
if (scope == "global") { if (scope == "global") {
......
...@@ -25,6 +25,12 @@ class CodeGenMetal final : public CodeGenC { ...@@ -25,6 +25,12 @@ class CodeGenMetal final : public CodeGenC {
void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) final; // NOLINT(*) void PrintType(Type t, std::ostream& os) final; // NOLINT(*)
void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
// print load of single element
void PrintVecElemLoad(
const std::string& vec, Type t, int i, std::ostream& os) final; // NOLINT(*)
// print store of single element.
void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value) final;
// overload visitor // overload visitor
void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const Broadcast* op, std::ostream& os) final; // NOLINT(*)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment