Commit 9ce7f0a3 by tqchen

Check in inline and test

parent 2fafa935
......@@ -18,6 +18,7 @@ namespace ir {
using Halide::Internal::ExprNode;
using Halide::Internal::IRNodeType;
using Halide::Internal::ForType;
/*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> {
......
/*!
* Copyright (c) 2016 by Contributors
* \file ir_pass.h
* \brief Collection of IR pass functions and visit functions
* \brief Collection of IR pass functions
*
* All the pass functions in this file are for Stmt,
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
*/
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
......@@ -22,14 +25,14 @@ namespace ir {
* \return Whether IR is in SSA form.
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
bool VerifySSA(const IRNodeRef& ir);
bool VerifySSA(const Stmt& ir);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
*/
Stmt ConvertSSA(const Stmt& stmt);
Stmt ConvertSSA(Stmt stmt);
/*!
* \brief inline all calls of f in stmt.
......@@ -42,8 +45,10 @@ Stmt ConvertSSA(const Stmt& stmt);
*
* \note All the passes in this file uses SSA form and outputs SSA form.
*/
Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt);
Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt);
} // namespace ir
} // namespace tvm
......
......@@ -6,6 +6,7 @@ from . import tensor as tensor
from . import expr
from . import stmt
from . import make
from . import ir_pass
from . import collections
from . import schedule
......
......@@ -224,21 +224,19 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace]
module_make = sys.modules["%s.make" % root_namespace]
namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace]
}
for name in op_names:
hdl = FunctionHandle()
check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
if name.startswith("_make_"):
fname = name[6:]
else:
fname = name
fname = name
target_module = module_internal if name.startswith('_') else module_obj
for k, v in namespace_match.items():
if name.startswith(k):
fname = name[len(k):]
target_module = v
function = _make_function(hdl, fname)
if name.startswith("_make_"):
setattr(module_make, function.__name__, function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
setattr(target_module, function.__name__, function)
"""Namespace of IR pass functions"""
......@@ -164,14 +164,16 @@ int TVMPushStack(ArgVariant arg,
API_BEGIN();
ret->arg_stack.resize(ret->arg_stack.size() + 1);
APIVariantValue& v = ret->arg_stack.back();
v.type_id = static_cast<ArgVariantID>(type_id);
if (type_id == kStr) {
v = arg.v_str;
v.str = arg.v_str;
} else if (type_id == kNodeHandle) {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle);
} else {
v.v_union = arg;
}
API_END_HANDLE_ERROR(ret->Clear());
}
......
......@@ -9,9 +9,7 @@
#include "./c_api_registry.h"
namespace tvm {
using namespace tvm::ir;
using namespace Halide::Internal;
namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
......@@ -135,4 +133,5 @@ REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
} // namespace ir
} // namespace tvm
......@@ -19,7 +19,6 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::make_const;
if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) {
......
/*!
* Copyright (c) 2016 by Contributors
* Exposre of pass functions.
* \file c_api_pass.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
namespace tvm {
namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0)); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \
}) \
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
} // namespace ir
} // namespace tvm
......@@ -43,6 +43,17 @@ inline Type String2Type(std::string s) {
return Type(code, bits, lanes);
}
inline const char* TypeId2Str(ArgVariantID type_id) {
switch (type_id) {
case kNull: return "Null";
case kLong: return "Long";
case kDouble: return "Double";
case kStr: return "Str";
case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_id=" << type_id; return "";
}
}
/*! \brief Variant container for API calls */
class APIVariantValue {
public:
......@@ -74,6 +85,11 @@ class APIVariantValue {
v_union.v_long = value;
return *this;
}
inline APIVariantValue& operator=(bool value) {
type_id = kLong;
v_union.v_long = value;
return *this;
}
inline APIVariantValue& operator=(std::string value) {
type_id = kStr;
str = std::move(value);
......@@ -130,11 +146,13 @@ class APIVariantValue {
return v_union.v_long;
}
inline operator bool() const {
CHECK_EQ(type_id, kLong);
CHECK_EQ(type_id, kLong)
<< "expect boolean(int) but get " << TypeId2Str(type_id);
return v_union.v_long != 0;
}
inline operator std::string() const {
CHECK_EQ(type_id, kStr);
CHECK_EQ(type_id, kStr)
<< "expect Str but get " << TypeId2Str(type_id);
return str;
}
inline operator Type() const {
......
......@@ -21,8 +21,9 @@ Expr Tensor::operator()(Array<Expr> indices) const {
CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size();
return Call::make(
auto n Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this);
return n;
}
......
/*!
* Copyright (c) 2016 by Contributors
* \file inline.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
namespace {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline : public IRMutator {
public:
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expr Mutate(Expr expr) final {
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
return InlineCall(call);
} else {
return IRMutator::Mutate(expr);
}
}
Stmt Mutate(Stmt stmt) final {
return IRMutator::Mutate(stmt);
}
private:
FunctionRef f_;
Array<Var> args_;
Expr body_;
Expr InlineCall(const Call* op) {
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
return expr;
}
};
} // namespace
Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt) {
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
}
} // namespace ir
} // namespace tvm
......@@ -156,13 +156,13 @@ class IRConvertSSA : public IRMutator {
} // namespace
bool VerifySSA(const IRNodeRef& ir) {
bool VerifySSA(const Stmt& ir) {
IRVerifySSA v;
v.Visit(ir);
return v.is_ssa;
}
Stmt ConvertSSA(const Stmt& stmt) {
Stmt ConvertSSA(Stmt stmt) {
return IRConvertSSA().Mutate(stmt);
}
......
......@@ -10,9 +10,9 @@ TEST(IRSSA, Convert) {
Var x("x"), y;
Expr let = Let::make(x, 1, x + 1);
auto z = let + let;
auto z = Evaluate::make(let + let);
CHECK(!ir::VerifySSA(z));
auto z_ssa = ir::ConvertSSA(Evaluate::make(z));
auto z_ssa = ir::ConvertSSA(z);
CHECK(ir::VerifySSA(z_ssa));
}
......@@ -20,7 +20,7 @@ TEST(IRSSA, Basic) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
auto z = Evaluate::make(x + y);
CHECK(ir::VerifySSA(z));
}
......
import tvm
def test_inline():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A(i) + 10, name='T')
X = T(100)
stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
stmt = tvm.ir_pass.Inline(
T, T.source_op.iter_var, T.source_op.body, stmt)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
if __name__ == "__main__":
test_inline()
import tvm
def test_verify_ssa():
x = tvm.Var('x')
y = tvm.Var()
z = tvm.make.Evaluate(x + y)
assert(tvm.ir_pass.VerifySSA(z))
def test_convert_ssa():
x = tvm.Var('x')
y = tvm.Var()
let = tvm.make.Let(x, 1, x + 1)
z = tvm.make.Evaluate(let + let)
assert(not tvm.ir_pass.VerifySSA(z))
z_ssa = tvm.ir_pass.ConvertSSA(z)
assert(tvm.ir_pass.VerifySSA(z_ssa))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment