Commit c5395a1f by tqchen

Expose testcase as bound inference to python, now push toward the goal!

parent 54c18a6c
...@@ -239,7 +239,8 @@ def _init_function_module(root_namespace): ...@@ -239,7 +239,8 @@ def _init_function_module(root_namespace):
module_internal = sys.modules["%s._function_internal" % root_namespace] module_internal = sys.modules["%s._function_internal" % root_namespace]
namespace_match = { namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace], "_make_" : sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace] "_pass_" : sys.modules["%s.ir_pass" % root_namespace],
"_schedule_" : sys.modules["%s.schedule" % root_namespace]
} }
for name in op_names: for name in op_names:
......
...@@ -7,10 +7,10 @@ ...@@ -7,10 +7,10 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./c_api_registry.h" #include "./c_api_registry.h"
#include "../schedule/bound.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
...@@ -21,6 +21,12 @@ using RetValue = APIVariantValue; ...@@ -21,6 +21,12 @@ using RetValue = APIVariantValue;
*ret = PassName(args.at(0)); \ *ret = PassName(args.at(0)); \
}) \ }) \
#define REGISTER_PASS2(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0), args.at(1)); \
}) \
#define REGISTER_PASS4(PassName) \ #define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \ TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \ .set_body([](const ArgStack& args, RetValue *ret) { \
......
/*!
* Copyright (c) 2016 by Contributors
* Implementation of API functions related to schedule pass.
* \file c_api_lang.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include "../schedule/bound.h"
#include "./c_api_registry.h"
namespace tvm {
namespace schedule {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0)); \
}) \
REGISTER_SCHEDULE_PASS1(InferBound);
} // namespace schedule
} // namespace tvm
...@@ -143,7 +143,7 @@ void InferBound(const Schedule& sch, ...@@ -143,7 +143,7 @@ void InferBound(const Schedule& sch,
} }
std::unordered_map<IterVar, Range> InferBound(Schedule sch) { Map<IterVar, Range> InferBound(Schedule sch) {
return {}; return {};
} }
......
...@@ -19,7 +19,7 @@ namespace schedule { ...@@ -19,7 +19,7 @@ namespace schedule {
* \param sch The root schedule to infer all the bounds. * \param sch The root schedule to infer all the bounds.
* \return the result bound of the iteration Variable * \return the result bound of the iteration Variable
*/ */
std::unordered_map<IterVar, Range> InferBound(Schedule sch); Map<IterVar, Range> InferBound(Schedule sch);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
......
import tvm
def test_bound_inference():
m = tvm.Var('m')
l = tvm.Var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j])
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
sA1 = tvm.Schedule(A1.op)
sA2 = tvm.Schedule(A2.op)
xo, xi = sA1.split(A1.op.dim_var[0], factor=8)
sA2.compute_at(sA1, xi)
bounds = tvm.schedule.InferBound(sA1)
assert isinstance(bounds, tvm.collections.Map)
print(bounds)
if __name__ == "__main__":
test_bound_inference()
...@@ -48,4 +48,3 @@ if __name__ == "__main__": ...@@ -48,4 +48,3 @@ if __name__ == "__main__":
test_schedule_create() test_schedule_create()
test_reorder() test_reorder()
test_tile() test_tile()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment