Commit aaf7ff04 by alex-weaver, committed by Tianqi Chen

Move BuildConfig context stack to C++ (#1025)

parent 7b098c9a
@@ -8,6 +8,7 @@
#include <string>
#include <vector>
#include <utility>
#include "./runtime/packed_func.h"
#include "./schedule_pass.h"
#include "./lowered_func.h"
@@ -203,6 +204,12 @@ class BuildConfigNode : public Node {
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;
/*! \brief Custom lower passes to add, as (phase index, pass function) pairs */
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
@@ -214,13 +221,70 @@ class BuildConfigNode : public Node {
v->Visit("restricted_func", &restricted_func);
v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
}
static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
};
TVM_DEFINE_NODE_REF(BuildConfig, BuildConfigNode);
/*!
* \brief Container for build configuration options
*/
class BuildConfig : public ::tvm::NodeRef {
public:
BuildConfig() {}
explicit BuildConfig(std::shared_ptr<::tvm::Node> n) : NodeRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(node_.get());
}
BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(node_.get());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
EXPORT static void EnterBuildConfigScope(const tvm::BuildConfig& build_config);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
EXPORT static void ExitBuildConfigScope();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
EXPORT static tvm::BuildConfig Current();
using ContainerType = BuildConfigNode;
};
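A minimal sketch of driving the new static API directly from C++, using the build_config() factory declared further down in this header (assumes a TVM build that contains this change; illustrative only, not part of the diff):

#include <tvm/build_module.h>

void ManualScopeExample() {
  tvm::BuildConfig cfg = tvm::build_config();  // fresh config with default values
  cfg->dump_pass_ir = true;

  tvm::BuildConfig::EnterBuildConfigScope(cfg);
  // BuildConfig::Current() now returns cfg, on this thread only.
  tvm::BuildConfig::ExitBuildConfigScope();
  // After the pop, Current() returns the previously pushed config,
  // or the thread-local default if the stack is empty.
}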
/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct BuildConfigContext {
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/
explicit BuildConfigContext(const tvm::BuildConfig& build_config) {
BuildConfig::EnterBuildConfigScope(build_config);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~BuildConfigContext() {
BuildConfig::ExitBuildConfigScope();
}
};
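The BuildConfigContext wrapper above gives the same behavior in exception-safe RAII form; nested scopes restore the outer configuration automatically. A sketch under the same assumptions:

#include <tvm/build_module.h>

void ScopedConfigExample() {
  tvm::BuildConfig outer = tvm::build_config();
  outer->unroll_explicit = false;
  tvm::BuildConfigContext outer_ctx(outer);    // pushes outer

  {
    tvm::BuildConfig inner = tvm::build_config();
    inner->partition_const_loop = true;
    tvm::BuildConfigContext inner_ctx(inner);  // pushes inner
    // BuildConfig::Current() is inner inside this block.
  }                                            // inner popped here
  // BuildConfig::Current() is outer again.
}                                              // outer popped here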
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
......
@@ -8,8 +8,8 @@ import warnings
import types
from ._ffi.node import NodeBase, register_node
from ._ffi.base import _RUNTIME_ONLY
from . import api
from . import _api_internal
from . import tensor
from . import schedule
from . import expr
@@ -46,7 +46,8 @@ class DumpIR(object):
retv = func(*args, **kwargs)
if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
return retv
pname = str(self._pass_id) + "_" + func.func_name + "_ir.cc"
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, container.LoweredFunc) else retv
f.write(str(out))
@@ -70,20 +71,20 @@ class DumpIR(object):
self._recover_list.append(recover)
vset[k] = self.decorate(v) if isinstance(v, types.FunctionType) else v
def decorate_custompass(self):
""" decorate add_lower_pass pass in BuildConfig"""
cfg = BuildConfig.current
self._old_custom_pass = cfg.add_lower_pass
custom_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
pass_list = [(x[0], self.decorate(x[1])) for x in custom_pass]
BuildConfig.current.add_lower_pass = pass_list
def decorate_custompass(self, custom_pass):
"""decorate given list of custom passes, and return decorated passes"""
custom_pass = custom_pass if custom_pass else []
pass_list = []
for idx, x in enumerate(custom_pass):
x[1].__name__ = "custom{}_phase{}".format(idx, x[0])
pass_list += [(x[0], self.decorate(x[1]))]
return pass_list
def enter(self):
"""only decorate outermost nest"""
if DumpIR.scope_level > 0:
return
self.decorate_irpass()
self.decorate_custompass()
self._pass_id = 0
DumpIR.scope_level += 1
@@ -95,7 +96,6 @@ class DumpIR(object):
for f in self._recover_list:
f()
schedule.ScheduleOps = self._old_sgpass
BuildConfig.current.add_lower_pass = self._old_custom_pass
DumpIR.scope_level -= 1
@register_node
@@ -113,7 +113,6 @@ class BuildConfig(NodeBase):
is constructed. See _node_defaults for the fields.
"""
current = None
_node_defaults = {
"auto_unroll_max_step": 0,
"auto_unroll_max_depth": 8,
@@ -124,8 +123,10 @@ class BuildConfig(NodeBase):
"offset_factor": 0,
"data_alignment": -1,
"restricted_func": True,
"double_buffer_split_loop": 1
"double_buffer_split_loop": 1,
"dump_pass_ir": False
}
_dump_ir = DumpIR()
# pylint: disable=no-member
def __init__(self, handle):
@@ -138,24 +139,28 @@ class BuildConfig(NodeBase):
"""
super(BuildConfig, self).__init__(handle)
self.handle = handle
self._old_scope = None
self._dump_ir = DumpIR()
self.dump_pass_ir = False
self.add_lower_pass = None
@property
def add_lower_pass(self):
size = _api_internal._BuildConfigGetAddLowerPassInfo(self)
result = []
for i in range(size):
phase = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, True)
func = _api_internal._BuildConfigGetAddLowerPassInfo(self, i, False)
result += [(phase, func)]
return result
def __enter__(self):
# pylint: disable=protected-access
self._old_scope = BuildConfig.current
BuildConfig.current = self
if self.dump_pass_ir is True:
self._dump_ir.enter()
_api_internal._EnterBuildConfigScope(self)
if self.dump_pass_ir:
BuildConfig._dump_ir.enter()
return self
def __exit__(self, ptype, value, trace):
assert self._old_scope
if self.dump_pass_ir is True:
self._dump_ir.exit()
BuildConfig.current = self._old_scope
if self.dump_pass_ir:
BuildConfig._dump_ir.exit()
_api_internal._ExitBuildConfigScope()
def __setattr__(self, name, value):
if name in BuildConfig._node_defaults:
@@ -163,6 +168,9 @@ class BuildConfig(NodeBase):
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value)
def current_build_config():
return _api_internal._GetCurrentBuildConfig()
def build_config(**kwargs):
"""Configure the build behavior by setting config variables.
@@ -221,14 +229,13 @@ def build_config(**kwargs):
for k, v in BuildConfig._node_defaults.items()}
config = make.node("BuildConfig", **node_args)
for k in kwargs:
if not k in node_args:
setattr(config, k, kwargs[k])
return config
if "add_lower_pass" in kwargs:
add_lower_pass_args = []
for x in kwargs["add_lower_pass"]:
add_lower_pass_args += [x[0], x[1]]
_api_internal._BuildConfigSetAddLowerPass(config, *add_lower_pass_args)
if not _RUNTIME_ONLY:
# BuildConfig is not available in tvm_runtime
BuildConfig.current = build_config()
return config
def get_binds(args, binds=None):
"""Internal function to get binds and arg_list given arguments.
@@ -252,7 +259,7 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
cfg = BuildConfig.current
cfg = current_build_config()
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
@@ -309,8 +316,10 @@ def lower(sch,
Then the Stmt before make api is returned.
"""
binds, arg_list = get_binds(args, binds)
cfg = BuildConfig.current
cfg = current_build_config()
add_lower_pass = cfg.add_lower_pass if cfg.add_lower_pass else []
if cfg.dump_pass_ir:
add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
@@ -434,7 +443,7 @@ def build(sch,
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == container.LoweredFunc.MixedFunc:
if BuildConfig.current.detect_global_barrier:
if current_build_config().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
warp_size = target.thread_warp_size
......
@@ -6,7 +6,7 @@ from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from .build_module import BuildConfig
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node
@register_node
@@ -74,7 +74,7 @@ def decl_tensor_intrin(op,
if not isinstance(t.op, _tensor.PlaceholderOp):
raise ValueError("Donot yet support composition op")
cfg = BuildConfig.current
cfg = current_build_config()
for t in tensors:
buf = (binds[t] if t in binds else
_api.decl_buffer(t.shape, t.dtype, t.op.name,
......
@@ -468,6 +468,41 @@ BuildConfig build_config() {
return BuildConfig(std::make_shared<BuildConfigNode>());
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMBuildConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
tvm::BuildConfig default_config;
/*! \brief The current build config context */
std::stack<tvm::BuildConfig> context_stack;
TVMBuildConfigThreadLocalEntry() :
default_config(build_config()) {
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
void BuildConfig::EnterBuildConfigScope(const tvm::BuildConfig& build_config) {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.push(build_config);
}
void BuildConfig::ExitBuildConfigScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.pop();
}
tvm::BuildConfig BuildConfig::Current() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
return entry->default_config;
}
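The per-thread stack with a lazily constructed default as fallback is a small, reusable pattern that is independent of TVM. A self-contained sketch of the same idea using only the standard library (the names here are illustrative, not TVM API):

#include <stack>
#include <string>

struct Config { std::string name; };  // stand-in for BuildConfig

// One instance per thread, mirroring TVMBuildConfigThreadLocalEntry above:
// an empty stack means "fall back to the default", exactly as Current() does.
struct ConfigContextStack {
  Config default_config{"default"};
  std::stack<Config> stack;

  static ConfigContextStack* Get() {
    static thread_local ConfigContextStack inst;
    return &inst;
  }
  static void Enter(const Config& c) { Get()->stack.push(c); }
  static void Exit() { Get()->stack.pop(); }
  static Config Current() {
    ConfigContextStack* e = Get();
    return e->stack.empty() ? e->default_config : e->stack.top();
  }
};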
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
@@ -482,7 +517,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop;
p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
p->stream << "dump_pass_ir=" << op->dump_pass_ir;
p->stream << ")";
});
@@ -571,6 +607,55 @@ void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
func.CallPacked(args, ret);
}
TVM_REGISTER_API("_GetCurrentBuildConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BuildConfig::Current();
});
TVM_REGISTER_API("_EnterBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig target = args[0];
BuildConfig::EnterBuildConfigScope(target);
});
TVM_REGISTER_API("_ExitBuildConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig::ExitBuildConfigScope();
});
TVM_REGISTER_API("_BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair(
args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc()));
}
cfg->add_lower_pass = add_lower_pass;
});
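_BuildConfigSetAddLowerPass expects its arguments flattened as (config, phase, func, phase, func, ...), which is what the CHECK on args.size() enforces and what the Python build_config() above sends. A sketch of the same call made from C++ through the global registry (assumes the underscore-prefixed internal APIs are reachable via tvm::runtime::Registry::Get, as they are from Python; illustrative only):

#include <tvm/build_module.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>

void SetLowerPassExample(tvm::BuildConfig cfg) {
  using tvm::runtime::PackedFunc;
  using tvm::runtime::TVMArgs;
  using tvm::runtime::TVMRetValue;

  // A do-nothing lower pass: returns its input statement unchanged.
  PackedFunc identity_pass([](TVMArgs args, TVMRetValue* rv) { *rv = args[0]; });

  const PackedFunc* set = tvm::runtime::Registry::Get("_BuildConfigSetAddLowerPass");
  // Flattened (phase, func) pairs follow the config: here a phase 1 and a phase 2 pass.
  (*set)(cfg, 1, identity_pass, 2, identity_pass);
}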
TVM_REGISTER_API("_BuildConfigGetAddLowerPassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// Return one of the following:
// * Size of add_lower_pass if num_args == 1
// * Phase index of pass if args are (config, index, true)
// * Function of pass if args are (config, index, false)
BuildConfig cfg = args[0];
if (args.num_args == 1) {
*ret = static_cast<int64_t>(cfg->add_lower_pass.size());
} else {
int index = args[1];
bool get_phase = args[2];
auto item = cfg->add_lower_pass[index];
if (get_phase) {
*ret = item.first;
} else {
*ret = item.second;
}
}
});
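For the reverse direction there is a single query entry point, overloaded on argument count as the comment above describes; the Python add_lower_pass property earlier in this change loops over it to rebuild the (phase, func) list. The equivalent loop in C++, under the same assumptions as the previous sketch:

#include <utility>
#include <vector>
#include <tvm/build_module.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>

std::vector<std::pair<int, tvm::runtime::PackedFunc> >
GetLowerPassExample(tvm::BuildConfig cfg) {
  using tvm::runtime::PackedFunc;
  const PackedFunc* get = tvm::runtime::Registry::Get("_BuildConfigGetAddLowerPassInfo");

  std::vector<std::pair<int, PackedFunc> > result;
  int64_t size = (*get)(cfg);                 // one argument: number of passes
  for (int i = 0; i < size; ++i) {
    int phase = (*get)(cfg, i, true);         // (config, index, true)  -> phase
    PackedFunc func = (*get)(cfg, i, false);  // (config, index, false) -> pass function
    result.emplace_back(phase, func);
  }
  return result;
}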
TVM_REGISTER_API("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......
@@ -38,6 +38,7 @@ if __name__ == "__main__":
file_list = os.listdir('./')
cc_file = end_with('.cc')
cc_file = filter(cc_file, file_list)
cc_file = [f for f in cc_file]
assert len(cc_file) == 3
for i in cc_file:
os.remove(i)
......