Commit aaf7ff04 by alex-weaver Committed by Tianqi Chen

Move BuildConfig context stack to C++ (#1025)

parent 7b098c9a
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
#include "./runtime/packed_func.h" #include "./runtime/packed_func.h"
#include "./schedule_pass.h" #include "./schedule_pass.h"
#include "./lowered_func.h" #include "./lowered_func.h"
...@@ -203,6 +204,12 @@ class BuildConfigNode : public Node { ...@@ -203,6 +204,12 @@ class BuildConfigNode : public Node {
/*! \brief Whether to partition const loop */ /*! \brief Whether to partition const loop */
bool partition_const_loop = false; bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment); v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor); v->Visit("offset_factor", &offset_factor);
...@@ -214,13 +221,70 @@ class BuildConfigNode : public Node { ...@@ -214,13 +221,70 @@ class BuildConfigNode : public Node {
v->Visit("restricted_func", &restricted_func); v->Visit("restricted_func", &restricted_func);
v->Visit("detect_global_barrier", &detect_global_barrier); v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop); v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
} }
static constexpr const char* _type_key = "BuildConfig"; static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node); TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
}; };
TVM_DEFINE_NODE_REF(BuildConfig, BuildConfigNode); /*!
* \brief Container for build configuration options
*/
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
}
BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
EXPORT static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
EXPORT static void ExitBuildConfigScope();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
EXPORT static tvm::BuildConfig Current();
using ContainerType = BuildConfigNode;
};
/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct BuildConfigContext {
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/
explicit BuildConfigContext(const tvm::BuildConfig& build_config) {
BuildConfig::EnterBuildConfigScope(build_config);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~BuildConfigContext() {
BuildConfig::ExitBuildConfigScope();
}
};
/*! /*!
* \brief Construct a BuildConfig containing a new BuildConfigNode * \brief Construct a BuildConfig containing a new BuildConfigNode
......
...@@ -8,8 +8,8 @@ import warnings ...@@ -8,8 +8,8 @@ import warnings
import types import types
from ._ffi.node import NodeBase, register_node from ._ffi.node import NodeBase, register_node
from ._ffi.base import _RUNTIME_ONLY
from . import api from . import api
from . import _api_internal
from . import tensor from . import tensor
from . import schedule from . import schedule
from . import expr from . import expr
...@@ -46,7 +46,8 @@ class DumpIR(object): ...@@ -46,7 +46,8 @@ class DumpIR(object):
retv = func(*args, **kwargs) retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)): if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
return retv return retv
pname = str(self._pass_id) + "_" + func.func_name + "_ir.cc" fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f: with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv out = retv.body if isinstance(retv, container.LoweredFunc) else retv
f.write(str(out)) f.write(str(out))
...@@ -70,20 +71,20 @@ class DumpIR(object): ...@@ -70,20 +71,20 @@ class DumpIR(object):
self._recover_list.append(recover) self._recover_list.append(recover)
vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v
def decorate_custompass(self): def decorate_custompass(self, custom_pass):
""" decorate add_lower_pass pass in BuildConfig""" """decorate given list of custom passes, and return decorated passes"""
cfg = BuildConfig.current custom_pass = custom_pass if custom_pass else []
self._old_custom_pass = cfg.add_lower_pass pass_list = []
custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] for idx, x in enumerate(custom_pass):
pass_list = [(x[0], self.decorate(x[1])) for x in custom_pass] x[1].__name__ = "custom{}_phase{}".format(idx, x[0])
BuildConfig.current.add_lower_pass = pass_list pass_list += [(x[0], self.decorate(x[1]))]
return pass_list
def enter(self): def enter(self):
"""only decorate outermost nest""" """only decorate outermost nest"""
if DumpIR.scope_level > 0: if DumpIR.scope_level > 0:
return return
self.decorate_irpass() self.decorate_irpass()
self.decorate_custompass()
self._pass_id = 0 self._pass_id = 0
DumpIR.scope_level += 1 DumpIR.scope_level += 1
...@@ -95,7 +96,6 @@ class DumpIR(object): ...@@ -95,7 +96,6 @@ class DumpIR(object):
for f in self._recover_list: for f in self._recover_list:
f() f()
schedule.ScheduleOps = self._old_sgpass schedule.ScheduleOps = self._old_sgpass
BuildConfig.current.add_lower_pass = self._old_custom_pass
DumpIR.scope_level -= 1 DumpIR.scope_level -= 1
@register_node @register_node
...@@ -113,7 +113,6 @@ class BuildConfig(NodeBase): ...@@ -113,7 +113,6 @@ class BuildConfig(NodeBase):
is constructed. See _node_defaults for the fields. is constructed. See _node_defaults for the fields.
""" """
current = None
_node_defaults = { _node_defaults = {
"auto_unroll_max_step": 0, "auto_unroll_max_step": 0,
"auto_unroll_max_depth": 8, "auto_unroll_max_depth": 8,
...@@ -124,8 +123,10 @@ class BuildConfig(NodeBase): ...@@ -124,8 +123,10 @@ class BuildConfig(NodeBase):
"offset_factor": 0, "offset_factor": 0,
"data_alignment": -1, "data_alignment": -1,
"restricted_func": True, "restricted_func": True,
"double_buffer_split_loop": 1 "double_buffer_split_loop": 1,
"dump_pass_ir": False
} }
_dump_ir = DumpIR()
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle): def __init__(self, handle):
...@@ -138,24 +139,28 @@ class BuildConfig(NodeBase): ...@@ -138,24 +139,28 @@ class BuildConfig(NodeBase):
""" """
super(BuildConfig, self).__init__(handle) super(BuildConfig, self).__init__(handle)
self.handle = handle self.handle = handle
self._old_scope = None
self._dump_ir = DumpIR() @property
self.dump_pass_ir = False def add_lower_pass(self):
self.add_lower_pass = None size = _api_internal._BuildConfigGetAddLowerPassInfo(self)
result = []
for i in range(size):
phase = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, True)
func = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, False)
result += [(phase, func)]
return result
def __enter__(self): def __enter__(self):
# pylint: disable=protected-access # pylint: disable=protected-access
self._old_scope = BuildConfig.current _api_internal._EnterBuildConfigScope(self)
BuildConfig.current = self if self.dump_pass_ir:
if self.dump_pass_ir is True: BuildConfig._dump_ir.enter()
self._dump_ir.enter()
return self return self
def __exit__(self, ptype, value, trace): def __exit__(self, ptype, value, trace):
assert self._old_scope if self.dump_pass_ir:
if self.dump_pass_ir is True: BuildConfig._dump_ir.exit()
self._dump_ir.exit() _api_internal._ExitBuildConfigScope()
BuildConfig.current = self._old_scope
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in BuildConfig._node_defaults: if name in BuildConfig._node_defaults:
...@@ -163,6 +168,9 @@ class BuildConfig(NodeBase): ...@@ -163,6 +168,9 @@ class BuildConfig(NodeBase):
"'%s' object cannot set attribute '%s'" % (str(type(self)), name)) "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value) return super(BuildConfig, self).__setattr__(name, value)
def current_build_config():
return _api_internal._GetCurrentBuildConfig()
def build_config(**kwargs): def build_config(**kwargs):
"""Configure the build behavior by setting config variables. """Configure the build behavior by setting config variables.
...@@ -221,14 +229,13 @@ def build_config(**kwargs): ...@@ -221,14 +229,13 @@ def build_config(**kwargs):
for k, v in BuildConfig._node_defaults.items()} for k, v in BuildConfig._node_defaults.items()}
config = make.node("BuildConfig", **node_args) config = make.node("BuildConfig", **node_args)
for k in kwargs: if "add_lower_pass" in kwargs:
if not k in node_args: add_lower_pass_args = []
setattr(config, k, kwargs[k]) for x in kwargs["add_lower_pass"]:
return config add_lower_pass_args += [x[0], x[1]]
_api_internal._BuildConfigSetAddLowerPass(config, *add_lower_pass_args)
if not _RUNTIME_ONLY: return config
# BuildConfig is not available in tvm_runtime
BuildConfig.current = build_config()
def get_binds(args, binds=None): def get_binds(args, binds=None):
"""Internal function to get binds and arg_list given arguments. """Internal function to get binds and arg_list given arguments.
...@@ -252,7 +259,7 @@ def get_binds(args, binds=None): ...@@ -252,7 +259,7 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments. The list of symbolic buffers of arguments.
""" """
binds = {} if binds is None else binds.copy() binds = {} if binds is None else binds.copy()
cfg = BuildConfig.current cfg = current_build_config()
arg_list = [] arg_list = []
for x in args: for x in args:
if isinstance(x, tensor.Tensor): if isinstance(x, tensor.Tensor):
...@@ -309,8 +316,10 @@ def lower(sch, ...@@ -309,8 +316,10 @@ def lower(sch,
Then the Stmt before make api is returned. Then the Stmt before make api is returned.
""" """
binds, arg_list = get_binds(args, binds) binds, arg_list = get_binds(args, binds)
cfg = BuildConfig.current cfg = current_build_config()
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else [] add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
if cfg.dump_pass_ir:
add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0] lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1] lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2] lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
...@@ -434,7 +443,7 @@ def build(sch, ...@@ -434,7 +443,7 @@ def build(sch,
"Direct host side access to device memory is detected in %s. " "Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name) "Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc: if func.func_type == container.LoweredFunc.MixedFunc:
if BuildConfig.current.detect_global_barrier: if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global") func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared") func = ir_pass.ThreadSync(func, "shared")
warp_size = target.thread_warp_size warp_size = target.thread_warp_size
......
...@@ -6,7 +6,7 @@ from . import expr as _expr ...@@ -6,7 +6,7 @@ from . import expr as _expr
from . import stmt as _stmt from . import stmt as _stmt
from . import make as _make from . import make as _make
from . import tensor as _tensor from . import tensor as _tensor
from .build_module import BuildConfig from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node from ._ffi.node import NodeBase, register_node
@register_node @register_node
...@@ -74,7 +74,7 @@ def decl_tensor_intrin(op, ...@@ -74,7 +74,7 @@ def decl_tensor_intrin(op,
if not isinstance(t.op, _tensor.PlaceholderOp): if not isinstance(t.op, _tensor.PlaceholderOp):
raise ValueError("Donot yet support composition op") raise ValueError("Donot yet support composition op")
cfg = BuildConfig.current cfg = current_build_config()
for t in tensors: for t in tensors:
buf = (binds[t] if t in binds else buf = (binds[t] if t in binds else
_api.decl_buffer(t.shape, t.dtype, t.op.name, _api.decl_buffer(t.shape, t.dtype, t.op.name,
......
...@@ -468,6 +468,41 @@ BuildConfig build_config() { ...@@ -468,6 +468,41 @@ BuildConfig build_config() {
return BuildConfig(std::make_shared<BuildConfigNode>()); return BuildConfig(std::make_shared<BuildConfigNode>());
} }
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMBuildConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
tvm::BuildConfig default_config;
/*! \brief The current build config context */
std::stack<tvm::BuildConfig> context_stack;
TVMBuildConfigThreadLocalEntry() :
default_config(build_config()) {
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.push(build_config);
}
void BuildConfig::ExitBuildConfigScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.pop();
}
tvm::BuildConfig BuildConfig::Current() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
return entry->default_config;
}
TVM_REGISTER_NODE_TYPE(BuildConfigNode); TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
...@@ -482,7 +517,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -482,7 +517,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "unroll_explicit=" << op->unroll_explicit << ", "; p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
p->stream << "restricted_func=" << op->restricted_func << ", "; p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", "; p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop; p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
p->stream << "dump_pass_ir=" << op->dump_pass_ir;
p->stream << ")"; p->stream << ")";
}); });
...@@ -571,6 +607,55 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const { ...@@ -571,6 +607,55 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
func.CallPacked(args, ret); func.CallPacked(args, ret);
} }
TVM_REGISTER_API("_GetCurrentBuildConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BuildConfig::Current();
});
TVM_REGISTER_API("_EnterBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig target = args[0];
BuildConfig::EnterBuildConfigScope(target);
});
TVM_REGISTER_API("_ExitBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig::ExitBuildConfigScope();
});
TVM_REGISTER_API("_BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair(
args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc()));
}
cfg->add_lower_pass = add_lower_pass;
});
TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// Return one of the following:
// * Size of add_lower_pass if num_args == 1
// * Phase index of pass if args are (config, index, true)
// * Function of pass if args are (config, index, false)
BuildConfig cfg = args[0];
if (args.num_args == 1) {
*ret = static_cast<int64_t>(cfg->add_lower_pass.size());
} else {
int index = args[1];
bool get_phase = args[2];
auto item = cfg->add_lower_pass[index];
if (get_phase) {
*ret = item.first;
} else {
*ret = item.second;
}
}
});
TVM_REGISTER_API("_GenericFuncCreate") TVM_REGISTER_API("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
......
...@@ -38,6 +38,7 @@ if __name__ == "__main__": ...@@ -38,6 +38,7 @@ if __name__ == "__main__":
file_list = os.listdir('./') file_list = os.listdir('./')
cc_file = end_with('.cc') cc_file = end_with('.cc')
cc_file = filter(cc_file, file_list) cc_file = filter(cc_file, file_list)
cc_file = [f for f in cc_file]
assert len(cc_file) == 3 assert len(cc_file) == 3
for i in cc_file: for i in cc_file:
os.remove(i) os.remove(i)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment