Commit 9082a93b by Tianqi Chen Committed by GitHub

Make target and build module more pythonic (#1089)

parent 819728db
......@@ -24,6 +24,8 @@ class TargetNode : public Node {
public:
/*! \brief The name of the target device */
std::string target_name;
/*! \brief The name of the target device */
std::string device_name;
/*! \brief The type of the target device */
int device_type;
/*! \brief The maximum threads that a schedule should use for this device */
......@@ -42,6 +44,7 @@ class TargetNode : public Node {
void VisitAttrs(AttrVisitor* v) final {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
......
......@@ -150,6 +150,13 @@ class BuildConfig(NodeBase):
result += [(phase, func)]
return result
@add_lower_pass.setter
def add_lower_pass(self, value):
add_lower_pass_args = []
for x in value:
add_lower_pass_args += [x[0], x[1]]
_api_internal._BuildConfigSetAddLowerPass(self, *add_lower_pass_args)
def __enter__(self):
# pylint: disable=protected-access
_api_internal._EnterBuildConfigScope(self)
......@@ -168,9 +175,12 @@ class BuildConfig(NodeBase):
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value)
def current_build_config():
"""Get the current build configuration."""
return _api_internal._GetCurrentBuildConfig()
def build_config(**kwargs):
"""Configure the build behavior by setting config variables.
......@@ -230,10 +240,7 @@ def build_config(**kwargs):
config = make.node("BuildConfig", **node_args)
if "add_lower_pass" in kwargs:
add_lower_pass_args = []
for x in kwargs["add_lower_pass"]:
add_lower_pass_args += [x[0], x[1]]
_api_internal._BuildConfigSetAddLowerPass(config, *add_lower_pass_args)
config.add_lower_pass = kwargs["add_lower_pass"]
return config
......
......@@ -166,6 +166,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
conn, addr, opts = _accept_conn(sock, tracker_conn)
except (socket.error, IOError):
# retry when tracker is dropped
if tracker_conn:
tracker_conn.close()
tracker_conn = None
continue
......
......@@ -37,8 +37,6 @@ Target CreateTarget(const std::string& target_name,
t->target_name = target_name;
std::string device_name = "";
std::string libs_flag = "-libs=";
std::string device_flag = "-device=";
for (auto& item : options) {
......@@ -51,12 +49,12 @@ Target CreateTarget(const std::string& target_name,
t->libs_array.push_back(ir::StringImm::make(lib_item));
}
} else if (item.find(device_flag) == 0) {
device_name = item.substr(device_flag.length());
t->device_name = item.substr(device_flag.length());
}
}
if (device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(device_name));
if (t->device_name.length() > 0) {
t->keys_array.push_back(ir::StringImm::make(t->device_name));
}
t->device_type = kDLCPU;
t->thread_warp_size = 1;
......@@ -78,7 +76,7 @@ Target CreateTarget(const std::string& target_name,
t->keys_array.push_back(ir::StringImm::make("rocm"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256;
if (device_name == "intel_gpu") {
if (t->device_name == "intel_gpu") {
t->thread_warp_size = 16;
}
} else if (target_name == "metal" || target_name == "vulkan") {
......
......@@ -36,6 +36,7 @@ def test_target_dispatch():
assert tvm.target.current_target() == None
def test_target_string_parse():
target = tvm.target.create("cuda -libs=cublas,cudnn")
......@@ -45,6 +46,9 @@ def test_target_string_parse():
assert target.libs == ['cublas', 'cudnn']
assert str(target) == str(tvm.target.cuda("-libs=cublas,cudnn"))
assert tvm.target.intel_gpu().device_name == "intel_gpu"
if __name__ == "__main__":
test_target_dispatch()
test_target_string_parse()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment