Commit 8065f2e0 by Lianmin Zheng Committed by Tianqi Chen

support string option when create cuda/rocm/... target (#1071)

parent 2f87ddc9
...@@ -349,10 +349,10 @@ def cuda(options=None): ...@@ -349,10 +349,10 @@ def cuda(options=None):
Parameters Parameters
---------- ----------
options : list of str options : str or list of str
Additional options Additional options
""" """
options = options if options else [] options = _merge_opts([], options)
return _api_internal._TargetCreate("cuda", *options) return _api_internal._TargetCreate("cuda", *options)
...@@ -361,10 +361,10 @@ def rocm(options=None): ...@@ -361,10 +361,10 @@ def rocm(options=None):
Parameters Parameters
---------- ----------
options : list of str options : str or list of str
Additional options Additional options
""" """
options = options if options else [] options = _merge_opts([], options)
return _api_internal._TargetCreate("rocm", *options) return _api_internal._TargetCreate("rocm", *options)
...@@ -373,7 +373,7 @@ def rasp(options=None): ...@@ -373,7 +373,7 @@ def rasp(options=None):
Parameters Parameters
---------- ----------
options : list of str options : str or list of str
Additional options Additional options
""" """
opts = ["-device=rasp", opts = ["-device=rasp",
...@@ -389,7 +389,7 @@ def mali(options=None): ...@@ -389,7 +389,7 @@ def mali(options=None):
Parameters Parameters
---------- ----------
options : list of str options : str or list of str
Additional options Additional options
""" """
opts = ["-device=mali"] opts = ["-device=mali"]
...@@ -402,10 +402,10 @@ def opengl(options=None): ...@@ -402,10 +402,10 @@ def opengl(options=None):
Parameters Parameters
---------- ----------
options : list of str options : str or list of str
Additional options Additional options
""" """
options = options if options else [] options = _merge_opts([], options)
return _api_internal._TargetCreate("opengl", *options) return _api_internal._TargetCreate("opengl", *options)
......
...@@ -43,6 +43,7 @@ def test_target_string_parse(): ...@@ -43,6 +43,7 @@ def test_target_string_parse():
assert target.options == ['-libs=cublas,cudnn'] assert target.options == ['-libs=cublas,cudnn']
assert target.keys == ['cuda', 'gpu'] assert target.keys == ['cuda', 'gpu']
assert target.libs == ['cublas', 'cudnn'] assert target.libs == ['cublas', 'cudnn']
assert str(target) == str(tvm.target.cuda("-libs=cublas,cudnn"))
if __name__ == "__main__": if __name__ == "__main__":
test_target_dispatch() test_target_dispatch()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment