Commit 523b6a6b by alex-weaver Committed by Tianqi Chen

Convert BuildModule to use TVM node system (#879)

* Make python BuildConfig serializable/deserializable to/from string

* Make C++ BuildConfig serializable/deserializable to/from string

* Revert "Make python BuildConfig serializable/deserializable to/from string"

This reverts commit a5e1fb3ff63a161cc0d63475d2a32816cc4c3666.

* Revert "Make C++ BuildConfig serializable/deserializable to/from string"

This reverts commit ec0c2c54543050fe6f264d06eebff33dee70370b.

* Converted BuildConfig to use TVM node system

* Fix lint

* Fix lint

* Added code to set node attributes through the C API

* Fixed bug in build_config()

* Fix lint

* Fix lint

* Fix test errors

* Reduced scope of node __setattr__ to apply only to BuildConfig

* Fix lint

* Fix lint

* Changed python BuildConfig to be immutable, with values set once on construction.

* Fix lint

* Fix C++ test

* Fixed BuildConfig setting python-side args

* Fix lint

* Removed dependency on reflection.cc to construct BuildConfig (allow use in runtime library)

* Fix lint

* Revert "Fix lint"

This reverts commit 16ed6d7a1ca5e551b035bad46e8361ea487cd45b.

* Revert "Removed dependency on reflection.cc to construct BuildConfig (allow use in runtime library)"

This reverts commit 43817c97a2ee045791e0c031d962fa97636ce8f6.

* Avoid accessing BuildConfig when using runtime lib

* Fix missing import

* Fix error running under cython (root cause: node handle is not valid until after __init__ has returned, so cannot call __dir__ during __init__

* Fix error where BuildConfig._node_defaults was not copied in build_config()

* Fix lint

* Fix lint

* Fix lint

* Fix lint

* Add comments to python BuildConfig
parent 54d4fe4b
...@@ -85,10 +85,13 @@ EXPORT Target stackvm(); ...@@ -85,10 +85,13 @@ EXPORT Target stackvm();
} // namespace target } // namespace target
class BuildConfig;
/*! /*!
* \brief Container for build configuration options * \brief Container for build configuration options
*/ */
struct BuildConfig { class BuildConfigNode : public Node {
public:
/*! /*!
* \brief The data alignment to use when constructing buffers. If this is set to * \brief The data alignment to use when constructing buffers. If this is set to
* -1, then TVM's internal default will be used * -1, then TVM's internal default will be used
...@@ -126,10 +129,31 @@ struct BuildConfig { ...@@ -126,10 +129,31 @@ struct BuildConfig {
/*! \brief Whether to partition const loop */ /*! \brief Whether to partition const loop */
bool partition_const_loop = false; bool partition_const_loop = false;
BuildConfig() { void VisitAttrs(AttrVisitor* v) final {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent);
v->Visit("unroll_explicit", &unroll_explicit);
v->Visit("restricted_func", &restricted_func);
v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop);
} }
static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_NODE_TYPE_INFO(BuildConfigNode, Node);
}; };
TVM_DEFINE_NODE_REF(BuildConfig, BuildConfigNode);
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
EXPORT BuildConfig build_config();
/*! /*!
* \brief Build a LoweredFunc given a schedule, args and binds * \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower. * \param sch The schedule to lower.
......
...@@ -6,7 +6,9 @@ LoweredFunc and compiled Module. ...@@ -6,7 +6,9 @@ LoweredFunc and compiled Module.
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import warnings import warnings
import types import types
import os
from ._ffi.node import NodeBase, register_node
from . import api from . import api
from . import tensor from . import tensor
from . import schedule from . import schedule
...@@ -18,6 +20,7 @@ from . import module ...@@ -18,6 +20,7 @@ from . import module
from . import codegen from . import codegen
from . import ndarray from . import ndarray
from . import target as _target from . import target as _target
from . import make
class DumpIR(object): class DumpIR(object):
"""Dump IR for each pass. """Dump IR for each pass.
...@@ -95,16 +98,23 @@ class DumpIR(object): ...@@ -95,16 +98,23 @@ class DumpIR(object):
BuildConfig.current.add_lower_pass = self._old_custom_pass BuildConfig.current.add_lower_pass = self._old_custom_pass
DumpIR.scope_level -= 1 DumpIR.scope_level -= 1
class BuildConfig(object): @register_node
class BuildConfig(NodeBase):
"""Configuration scope to set a build config option. """Configuration scope to set a build config option.
Parameters Note
---------- ----
kwargs This object is backed by node system in C++, with arguments that can be
Keyword arguments of configurations to set. exchanged between python and C++.
Do not construct directly, use build_config instead.
The fields that are backed by the C++ node are immutable once an instance
is constructed. See _node_defaults for the fields.
""" """
current = None current = None
defaults = { _node_defaults = {
"auto_unroll_max_step": 0, "auto_unroll_max_step": 0,
"auto_unroll_max_depth": 8, "auto_unroll_max_depth": 8,
"auto_unroll_max_extent": 0, "auto_unroll_max_extent": 0,
...@@ -114,30 +124,28 @@ class BuildConfig(object): ...@@ -114,30 +124,28 @@ class BuildConfig(object):
"offset_factor": 0, "offset_factor": 0,
"data_alignment": -1, "data_alignment": -1,
"restricted_func": True, "restricted_func": True,
"double_buffer_split_loop": 1, "double_buffer_split_loop": 1
"add_lower_pass": None,
"dump_pass_ir": False
} }
def __init__(self, **kwargs):
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
super(BuildConfig, self).__init__(handle)
self.handle = handle
self._old_scope = None self._old_scope = None
self._dump_ir = DumpIR() self._dump_ir = DumpIR()
for k, _ in kwargs.items(): self.dump_pass_ir = False
if k not in BuildConfig.defaults: self.add_lower_pass = None
raise ValueError(
"invalid argument %s, candidates are %s" % (k, BuildConfig.defaults.keys()))
self._attr = kwargs
def __getattr__(self, name):
if name not in self._attr:
return BuildConfig.defaults[name]
return self._attr[name]
def __enter__(self): def __enter__(self):
# pylint: disable=protected-access # pylint: disable=protected-access
self._old_scope = BuildConfig.current self._old_scope = BuildConfig.current
attr = BuildConfig.current._attr.copy()
attr.update(self._attr)
self._attr = attr
BuildConfig.current = self BuildConfig.current = self
if self.dump_pass_ir is True: if self.dump_pass_ir is True:
self._dump_ir.enter() self._dump_ir.enter()
...@@ -149,8 +157,11 @@ class BuildConfig(object): ...@@ -149,8 +157,11 @@ class BuildConfig(object):
self._dump_ir.exit() self._dump_ir.exit()
BuildConfig.current = self._old_scope BuildConfig.current = self._old_scope
def __setattr__(self, name, value):
BuildConfig.current = BuildConfig() if name in BuildConfig._node_defaults:
raise AttributeError(
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(BuildConfig, self).__setattr__(name, value)
def build_config(**kwargs): def build_config(**kwargs):
"""Configure the build behavior by setting config variables. """Configure the build behavior by setting config variables.
...@@ -206,8 +217,18 @@ def build_config(**kwargs): ...@@ -206,8 +217,18 @@ def build_config(**kwargs):
config: BuildConfig config: BuildConfig
The build configuration The build configuration
""" """
return BuildConfig(**kwargs) node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in BuildConfig._node_defaults.items()}
config = make.node("BuildConfig", **node_args)
for k in kwargs:
if not k in node_args:
setattr(config, k, kwargs[k])
return config
if not os.environ.get("TVM_USE_RUNTIME_LIB", False):
# BuildConfig is not available in tvm_runtime
BuildConfig.current = build_config()
def get_binds(args, binds=None): def get_binds(args, binds=None):
"""Internal function to get binds and arg_list given arguments. """Internal function to get binds and arg_list given arguments.
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <tvm/schedule.h> #include <tvm/schedule.h>
#include <tvm/api_registry.h> #include <tvm/api_registry.h>
#include <tvm/build_module.h>
namespace tvm { namespace tvm {
......
...@@ -179,7 +179,7 @@ void GetBinds(const Array<Tensor>& args, ...@@ -179,7 +179,7 @@ void GetBinds(const Array<Tensor>& args,
for (const auto &x : args) { for (const auto &x : args) {
if (out_binds->find(x) == out_binds->end()) { if (out_binds->find(x) == out_binds->end()) {
auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name, auto buf = BufferWithOffsetAlignment(x->shape, x->dtype, x->op->name,
config.data_alignment, config.offset_factor); config->data_alignment, config->offset_factor);
out_binds->Set(x, buf); out_binds->Set(x, buf);
out_arg_list->push_back(buf); out_arg_list->push_back(buf);
} else { } else {
...@@ -218,14 +218,14 @@ Stmt BuildStmt(Schedule sch, ...@@ -218,14 +218,14 @@ Stmt BuildStmt(Schedule sch,
stmt = ir::StorageFlatten(stmt, out_binds, 64); stmt = ir::StorageFlatten(stmt, out_binds, 64);
stmt = ir::CanonicalSimplify(stmt); stmt = ir::CanonicalSimplify(stmt);
if (loop_partition) { if (loop_partition) {
stmt = ir::LoopPartition(stmt, config.partition_const_loop); stmt = ir::LoopPartition(stmt, config->partition_const_loop);
} }
stmt = ir::VectorizeLoop(stmt); stmt = ir::VectorizeLoop(stmt);
stmt = ir::InjectVirtualThread(stmt); stmt = ir::InjectVirtualThread(stmt);
stmt = ir::InjectDoubleBuffer(stmt, config.double_buffer_split_loop); stmt = ir::InjectDoubleBuffer(stmt, config->double_buffer_split_loop);
stmt = ir::StorageRewrite(stmt); stmt = ir::StorageRewrite(stmt);
stmt = ir::UnrollLoop(stmt, config.auto_unroll_max_step, config.auto_unroll_max_depth, stmt = ir::UnrollLoop(stmt, config->auto_unroll_max_step, config->auto_unroll_max_depth,
config.auto_unroll_max_extent, config.unroll_explicit); config->auto_unroll_max_extent, config->unroll_explicit);
// Phase 2 // Phase 2
stmt = ir::Simplify(stmt); stmt = ir::Simplify(stmt);
...@@ -243,7 +243,7 @@ Array<LoweredFunc> lower(Schedule sch, ...@@ -243,7 +243,7 @@ Array<LoweredFunc> lower(Schedule sch,
const BuildConfig& config) { const BuildConfig& config) {
Array<NodeRef> out_arg_list; Array<NodeRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config); auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config.restricted_func) }); return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
} }
runtime::Module build(const Array<LoweredFunc>& funcs, runtime::Module build(const Array<LoweredFunc>& funcs,
...@@ -266,7 +266,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs, ...@@ -266,7 +266,7 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
for (const auto &x : funcs) { for (const auto &x : funcs) {
if (x->func_type == kMixedFunc) { if (x->func_type == kMixedFunc) {
auto func = x; auto func = x;
if (config.detect_global_barrier) { if (config->detect_global_barrier) {
func = ir::ThreadSync(func, "global"); func = ir::ThreadSync(func, "global");
} }
...@@ -321,4 +321,27 @@ runtime::Module build(const Array<LoweredFunc>& funcs, ...@@ -321,4 +321,27 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
return mhost; return mhost;
} }
BuildConfig build_config() {
return BuildConfig(std::make_shared<BuildConfigNode>());
}
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<BuildConfigNode>([](const BuildConfigNode *op, IRPrinter *p) {
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
p->stream << "offset_factor=" << op->offset_factor << ", ";
p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", ";
p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop;
p->stream << ")";
});
} // namespace tvm } // namespace tvm
...@@ -27,7 +27,7 @@ TEST(BuildModule, Basic) { ...@@ -27,7 +27,7 @@ TEST(BuildModule, Basic) {
auto args = Array<Tensor>({ A, B, C }); auto args = Array<Tensor>({ A, B, C });
std::unordered_map<Tensor, Buffer> binds; std::unordered_map<Tensor, Buffer> binds;
BuildConfig config; auto config = build_config();
auto target = target::llvm(); auto target = target::llvm();
auto lowered = lower(s, args, "func", binds, config); auto lowered = lower(s, args, "func", binds, config);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment