Commit 9082a93b by Tianqi Chen Committed by GitHub

Make target and build module more pythonic (#1089)

parent 819728db
...@@ -24,6 +24,8 @@ class TargetNode : public Node { ...@@ -24,6 +24,8 @@ class TargetNode : public Node {
public: public:
/*! \brief The name of the target device */ /*! \brief The name of the target device */
std::string target_name; std::string target_name;
/*! \brief The name of the target device */
std::string device_name;
/*! \brief The type of the target device */ /*! \brief The type of the target device */
int device_type; int device_type;
/*! \brief The maximum threads that a schedule should use for this device */ /*! \brief The maximum threads that a schedule should use for this device */
...@@ -42,6 +44,7 @@ class TargetNode : public Node { ...@@ -42,6 +44,7 @@ class TargetNode : public Node {
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("target_name", &target_name); v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type); v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads); v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size); v->Visit("thread_warp_size", &thread_warp_size);
......
...@@ -150,6 +150,13 @@ class BuildConfig(NodeBase): ...@@ -150,6 +150,13 @@ class BuildConfig(NodeBase):
result += [(phase, func)] result += [(phase, func)]
return result return result
@add_lower_pass.setter
def add_lower_pass(self, value):
add_lower_pass_args = []
for x in value:
add_lower_pass_args += [x[0], x[1]]
_api_internal._BuildConfigSetAddLowerPass(self, *add_lower_pass_args)
def __enter__(self): def __enter__(self):
# pylint: disable=protected-access # pylint: disable=protected-access
_api_internal._EnterBuildConfigScope(self) _api_internal._EnterBuildConfigScope(self)
...@@ -168,9 +175,12 @@ class BuildConfig(NodeBase): ...@@ -168,9 +175,12 @@ class BuildConfig(NodeBase):
"'%s' object cannot set attribute '%s'" % (str(type(self)), name)) "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value) return super(BuildConfig, self).__setattr__(name, value)
def current_build_config(): def current_build_config():
"""Get the current build configuration."""
return _api_internal._GetCurrentBuildConfig() return _api_internal._GetCurrentBuildConfig()
def build_config(**kwargs): def build_config(**kwargs):
"""Configure the build behavior by setting config variables. """Configure the build behavior by setting config variables.
...@@ -230,10 +240,7 @@ def build_config(**kwargs): ...@@ -230,10 +240,7 @@ def build_config(**kwargs):
config = make.node("BuildConfig", **node_args) config = make.node("BuildConfig", **node_args)
if "add_lower_pass" in kwargs: if "add_lower_pass" in kwargs:
add_lower_pass_args = [] config.add_lower_pass = kwargs["add_lower_pass"]
for x in kwargs["add_lower_pass"]:
add_lower_pass_args += [x[0], x[1]]
_api_internal._BuildConfigSetAddLowerPass(config, *add_lower_pass_args)
return config return config
......
...@@ -166,6 +166,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr): ...@@ -166,6 +166,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
conn, addr, opts = _accept_conn(sock, tracker_conn) conn, addr, opts = _accept_conn(sock, tracker_conn)
except (socket.error, IOError): except (socket.error, IOError):
# retry when tracker is dropped # retry when tracker is dropped
if tracker_conn:
tracker_conn.close() tracker_conn.close()
tracker_conn = None tracker_conn = None
continue continue
......
...@@ -37,8 +37,6 @@ Target CreateTarget(const std::string& target_name, ...@@ -37,8 +37,6 @@ Target CreateTarget(const std::string& target_name,
t->target_name = target_name; t->target_name = target_name;
std::string device_name = "";
std::string libs_flag = "-libs="; std::string libs_flag = "-libs=";
std::string device_flag = "-device="; std::string device_flag = "-device=";
for (auto& item : options) { for (auto& item : options) {
...@@ -51,12 +49,12 @@ Target CreateTarget(const std::string& target_name, ...@@ -51,12 +49,12 @@ Target CreateTarget(const std::string& target_name,
t->libs_array.push_back(ir::StringImm::make(lib_item)); t->libs_array.push_back(ir::StringImm::make(lib_item));
} }
} else if (item.find(device_flag) == 0) { } else if (item.find(device_flag) == 0) {
device_name = item.substr(device_flag.length()); t->device_name = item.substr(device_flag.length());
} }
} }
if (device_name.length() > 0) { if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(device_name)); t->keys_array.push_back(ir::StringImm::make(t->device_name));
} }
t->device_type = kDLCPU; t->device_type = kDLCPU;
t->thread_warp_size = 1; t->thread_warp_size = 1;
...@@ -78,7 +76,7 @@ Target CreateTarget(const std::string& target_name, ...@@ -78,7 +76,7 @@ Target CreateTarget(const std::string& target_name,
t->keys_array.push_back(ir::StringImm::make("rocm")); t->keys_array.push_back(ir::StringImm::make("rocm"));
t->keys_array.push_back(ir::StringImm::make("gpu")); t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256; t->max_num_threads = 256;
if (device_name == "intel_gpu") { if (t->device_name == "intel_gpu") {
t->thread_warp_size = 16; t->thread_warp_size = 16;
} }
} else if (target_name == "metal" || target_name == "vulkan") { } else if (target_name == "metal" || target_name == "vulkan") {
......
...@@ -36,6 +36,7 @@ def test_target_dispatch(): ...@@ -36,6 +36,7 @@ def test_target_dispatch():
assert tvm.target.current_target() == None assert tvm.target.current_target() == None
def test_target_string_parse(): def test_target_string_parse():
target = tvm.target.create("cuda -libs=cublas,cudnn") target = tvm.target.create("cuda -libs=cublas,cudnn")
...@@ -45,6 +46,9 @@ def test_target_string_parse(): ...@@ -45,6 +46,9 @@ def test_target_string_parse():
assert target.libs == ['cublas', 'cudnn'] assert target.libs == ['cublas', 'cudnn']
assert str(target) == str(tvm.target.cuda("-libs=cublas,cudnn")) assert str(target) == str(tvm.target.cuda("-libs=cublas,cudnn"))
assert tvm.target.intel_gpu().device_name == "intel_gpu"
if __name__ == "__main__": if __name__ == "__main__":
test_target_dispatch() test_target_dispatch()
test_target_string_parse() test_target_string_parse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment