Commit 877254f4 by Tianqi Chen Committed by GitHub

Fix metal backward compatibility (#1105)

parent ce02ee3b
...@@ -83,8 +83,11 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -83,8 +83,11 @@ class MetalModuleNode final :public runtime::ModuleNode {
if (e.lib == nil) { if (e.lib == nil) {
if (fmt_ == "metal") { if (fmt_ == "metal") {
MTLCompileOptions *opts = [MTLCompileOptions alloc]; MTLCompileOptions *opts = [MTLCompileOptions alloc];
opts.languageVersion = MTLLanguageVersion2_0; // Use the default setting for now.
opts.fastMathEnabled = YES; // by default most recent version is used.
// opts.languageVersion = MTLLanguageVersion2_0;
// opts.fastMathEnabled = YES;
opts = nil;
e.lib = [ e.lib = [
w->devices[device_id] w->devices[device_id]
newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()] newLibraryWithSource:[NSString stringWithUTF8String:data_.c_str()]
......
...@@ -98,8 +98,8 @@ def test_popcount(): ...@@ -98,8 +98,8 @@ def test_popcount():
check_device("llvm") check_device("llvm")
check_device("cuda") check_device("cuda")
check_device("opencl") check_device("opencl")
check_device("metal")
if dtype == "uint32": if dtype == "uint32":
check_device("metal")
check_device("vulkan") check_device("vulkan")
run('uint32') run('uint32')
run('uint64') run('uint64')
...@@ -146,9 +146,10 @@ def test_add(): ...@@ -146,9 +146,10 @@ def test_add():
c.asnumpy(), a.asnumpy() + b.asnumpy(), rtol=1e-6) c.asnumpy(), a.asnumpy() + b.asnumpy(), rtol=1e-6)
check_device("opencl") check_device("opencl")
check_device("metal")
check_device("cuda") check_device("cuda")
check_device("vulkan") if dtype == "float32":
check_device("metal")
check_device("vulkan")
run("float32") run("float32")
run("int32") run("int32")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment