Commit 9ce7f0a3 by tqchen

Check in inline and test

parent 2fafa935
...@@ -18,6 +18,7 @@ namespace ir { ...@@ -18,6 +18,7 @@ namespace ir {
using Halide::Internal::ExprNode; using Halide::Internal::ExprNode;
using Halide::Internal::IRNodeType; using Halide::Internal::IRNodeType;
using Halide::Internal::ForType;
/*! \brief Reduction operator operator */ /*! \brief Reduction operator operator */
struct Reduce : public ExprNode<Reduce> { struct Reduce : public ExprNode<Reduce> {
......
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file ir_pass.h * \file ir_pass.h
* \brief Collection of IR pass functions and visit functions * \brief Collection of IR pass functions
*
* All the pass functions in this file are for Stmt,
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
*/ */
#ifndef TVM_IR_PASS_H_ #ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_ #define TVM_IR_PASS_H_
...@@ -22,14 +25,14 @@ namespace ir { ...@@ -22,14 +25,14 @@ namespace ir {
* \return Whether IR is in SSA form. * \return Whether IR is in SSA form.
* \note All the passes in this file uses SSA form and outputs SSA form. * \note All the passes in this file uses SSA form and outputs SSA form.
*/ */
bool VerifySSA(const IRNodeRef& ir); bool VerifySSA(const Stmt& ir);
/*! /*!
* \brief Convert a IR node to be SSA form. * \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted. * \param stmt The source statement to be converted.
* \return The converted form. * \return The converted form.
*/ */
Stmt ConvertSSA(const Stmt& stmt); Stmt ConvertSSA(Stmt stmt);
/*! /*!
* \brief inline all calls of f in stmt. * \brief inline all calls of f in stmt.
...@@ -42,8 +45,10 @@ Stmt ConvertSSA(const Stmt& stmt); ...@@ -42,8 +45,10 @@ Stmt ConvertSSA(const Stmt& stmt);
* *
* \note All the passes in this file uses SSA form and outputs SSA form. * \note All the passes in this file uses SSA form and outputs SSA form.
*/ */
Stmt InlineSSA(FunctionRef f, const std::vector<Var>& args, Expr body, Stmt stmt); Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -6,6 +6,7 @@ from . import tensor as tensor ...@@ -6,6 +6,7 @@ from . import tensor as tensor
from . import expr from . import expr
from . import stmt from . import stmt
from . import make from . import make
from . import ir_pass
from . import collections from . import collections
from . import schedule from . import schedule
......
...@@ -224,21 +224,19 @@ def _init_function_module(root_namespace): ...@@ -224,21 +224,19 @@ def _init_function_module(root_namespace):
module_obj = sys.modules["%s.function" % root_namespace] module_obj = sys.modules["%s.function" % root_namespace]
module_internal = sys.modules["%s._function_internal" % root_namespace] module_internal = sys.modules["%s._function_internal" % root_namespace]
module_make = sys.modules["%s.make" % root_namespace] namespace_match = {
"_make_" : sys.modules["%s.make" % root_namespace],
"_pass_" : sys.modules["%s.ir_pass" % root_namespace]
}
for name in op_names: for name in op_names:
hdl = FunctionHandle() hdl = FunctionHandle()
check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl))) check_call(_LIB.TVMGetFunctionHandle(c_str(name), ctypes.byref(hdl)))
if name.startswith("_make_"): fname = name
fname = name[6:] target_module = module_internal if name.startswith('_') else module_obj
else: for k, v in namespace_match.items():
fname = name if name.startswith(k):
fname = name[len(k):]
target_module = v
function = _make_function(hdl, fname) function = _make_function(hdl, fname)
setattr(target_module, function.__name__, function)
if name.startswith("_make_"):
setattr(module_make, function.__name__, function)
elif function.__name__.startswith('_'):
setattr(module_internal, function.__name__, function)
else:
setattr(module_obj, function.__name__, function)
"""Namespace of IR pass functions"""
...@@ -164,14 +164,16 @@ int TVMPushStack(ArgVariant arg, ...@@ -164,14 +164,16 @@ int TVMPushStack(ArgVariant arg,
API_BEGIN(); API_BEGIN();
ret->arg_stack.resize(ret->arg_stack.size() + 1); ret->arg_stack.resize(ret->arg_stack.size() + 1);
APIVariantValue& v = ret->arg_stack.back(); APIVariantValue& v = ret->arg_stack.back();
v.type_id = static_cast<ArgVariantID>(type_id); v.type_id = static_cast<ArgVariantID>(type_id);
if (type_id == kStr) { if (type_id == kStr) {
v = arg.v_str; v.str = arg.v_str;
} else if (type_id == kNodeHandle) { } else if (type_id == kNodeHandle) {
v.sptr = *static_cast<TVMAPINode*>(arg.v_handle); v.sptr = *static_cast<TVMAPINode*>(arg.v_handle);
} else { } else {
v.v_union = arg; v.v_union = arg;
} }
API_END_HANDLE_ERROR(ret->Clear()); API_END_HANDLE_ERROR(ret->Clear());
} }
......
...@@ -9,9 +9,7 @@ ...@@ -9,9 +9,7 @@
#include "./c_api_registry.h" #include "./c_api_registry.h"
namespace tvm { namespace tvm {
namespace ir {
using namespace tvm::ir;
using namespace Halide::Internal;
using ArgStack = const std::vector<APIVariantValue>; using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue; using RetValue = APIVariantValue;
...@@ -135,4 +133,5 @@ REGISTER_MAKE2(Block); ...@@ -135,4 +133,5 @@ REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse); REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate); REGISTER_MAKE1(Evaluate);
} // namespace ir
} // namespace tvm } // namespace tvm
...@@ -19,7 +19,6 @@ using RetValue = APIVariantValue; ...@@ -19,7 +19,6 @@ using RetValue = APIVariantValue;
TVM_REGISTER_API(_const) TVM_REGISTER_API(_const)
.set_body([](const ArgStack& args, RetValue *ret) { .set_body([](const ArgStack& args, RetValue *ret) {
using Halide::Internal::make_const; using Halide::Internal::make_const;
if (args.at(0).type_id == kLong) { if (args.at(0).type_id == kLong) {
*ret = make_const(args.at(1), args.at(0).operator int64_t()); *ret = make_const(args.at(1), args.at(0).operator int64_t());
} else if (args.at(0).type_id == kDouble) { } else if (args.at(0).type_id == kDouble) {
......
/*!
* Copyright (c) 2016 by Contributors
* Exposre of pass functions.
* \file c_api_pass.cc
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
namespace tvm {
namespace ir {
using ArgStack = const std::vector<APIVariantValue>;
using RetValue = APIVariantValue;
// make from two arguments
#define REGISTER_PASS1(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0)); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API(_pass_## PassName) \
.set_body([](const ArgStack& args, RetValue *ret) { \
*ret = PassName(args.at(0), args.at(1), args.at(2), args.at(3)); \
}) \
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
} // namespace ir
} // namespace tvm
...@@ -43,6 +43,17 @@ inline Type String2Type(std::string s) { ...@@ -43,6 +43,17 @@ inline Type String2Type(std::string s) {
return Type(code, bits, lanes); return Type(code, bits, lanes);
} }
inline const char* TypeId2Str(ArgVariantID type_id) {
switch (type_id) {
case kNull: return "Null";
case kLong: return "Long";
case kDouble: return "Double";
case kStr: return "Str";
case kNodeHandle: return "NodeHandle";
default: LOG(FATAL) << "unknown type_id=" << type_id; return "";
}
}
/*! \brief Variant container for API calls */ /*! \brief Variant container for API calls */
class APIVariantValue { class APIVariantValue {
public: public:
...@@ -74,6 +85,11 @@ class APIVariantValue { ...@@ -74,6 +85,11 @@ class APIVariantValue {
v_union.v_long = value; v_union.v_long = value;
return *this; return *this;
} }
inline APIVariantValue& operator=(bool value) {
type_id = kLong;
v_union.v_long = value;
return *this;
}
inline APIVariantValue& operator=(std::string value) { inline APIVariantValue& operator=(std::string value) {
type_id = kStr; type_id = kStr;
str = std::move(value); str = std::move(value);
...@@ -130,11 +146,13 @@ class APIVariantValue { ...@@ -130,11 +146,13 @@ class APIVariantValue {
return v_union.v_long; return v_union.v_long;
} }
inline operator bool() const { inline operator bool() const {
CHECK_EQ(type_id, kLong); CHECK_EQ(type_id, kLong)
<< "expect boolean(int) but get " << TypeId2Str(type_id);
return v_union.v_long != 0; return v_union.v_long != 0;
} }
inline operator std::string() const { inline operator std::string() const {
CHECK_EQ(type_id, kStr); CHECK_EQ(type_id, kStr)
<< "expect Str but get " << TypeId2Str(type_id);
return str; return str;
} }
inline operator Type() const { inline operator Type() const {
......
...@@ -21,8 +21,9 @@ Expr Tensor::operator()(Array<Expr> indices) const { ...@@ -21,8 +21,9 @@ Expr Tensor::operator()(Array<Expr> indices) const {
CHECK_EQ(ndim(), indices.size()) CHECK_EQ(ndim(), indices.size())
<< "Tensor dimension mismatch in read" << "Tensor dimension mismatch in read"
<< "ndim = " << ndim() << ", indices.size=" << indices.size(); << "ndim = " << ndim() << ", indices.size=" << indices.size();
return Call::make( auto n Call::make(
(*this)->dtype, (*this)->name, indices, Call::Halide, *this); (*this)->dtype, (*this)->name, indices, Call::Halide, *this);
return n;
} }
......
/*!
* Copyright (c) 2016 by Contributors
* \file inline.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
namespace tvm {
namespace ir {
namespace {
// inliner to inline a function
// the result may not be SSA,
// ConvertSSA need to be applied after this pass
class IRInline : public IRMutator {
public:
IRInline(FunctionRef f, Array<Var> args, Expr body)
: f_(f), args_(args), body_(body) {}
Expr Mutate(Expr expr) final {
const Call* call = expr.as<Call>();
if (call != nullptr && call->func == f_) {
return InlineCall(call);
} else {
return IRMutator::Mutate(expr);
}
}
Stmt Mutate(Stmt stmt) final {
return IRMutator::Mutate(stmt);
}
private:
FunctionRef f_;
Array<Var> args_;
Expr body_;
Expr InlineCall(const Call* op) {
Expr expr = body_;
CHECK_EQ(args_.size(), op->args.size())
<< op->args.size() << " vs " << args_.size();
for (size_t i = 0; i < args_.size(); ++i) {
expr = Let::make(args_[i], op->args[i], expr);
}
return expr;
}
};
} // namespace
Stmt Inline(FunctionRef f,
Array<Var> args,
Expr body,
Stmt stmt) {
return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
}
} // namespace ir
} // namespace tvm
...@@ -156,13 +156,13 @@ class IRConvertSSA : public IRMutator { ...@@ -156,13 +156,13 @@ class IRConvertSSA : public IRMutator {
} // namespace } // namespace
bool VerifySSA(const IRNodeRef& ir) { bool VerifySSA(const Stmt& ir) {
IRVerifySSA v; IRVerifySSA v;
v.Visit(ir); v.Visit(ir);
return v.is_ssa; return v.is_ssa;
} }
Stmt ConvertSSA(const Stmt& stmt) { Stmt ConvertSSA(Stmt stmt) {
return IRConvertSSA().Mutate(stmt); return IRConvertSSA().Mutate(stmt);
} }
......
...@@ -10,9 +10,9 @@ TEST(IRSSA, Convert) { ...@@ -10,9 +10,9 @@ TEST(IRSSA, Convert) {
Var x("x"), y; Var x("x"), y;
Expr let = Let::make(x, 1, x + 1); Expr let = Let::make(x, 1, x + 1);
auto z = let + let; auto z = Evaluate::make(let + let);
CHECK(!ir::VerifySSA(z)); CHECK(!ir::VerifySSA(z));
auto z_ssa = ir::ConvertSSA(Evaluate::make(z)); auto z_ssa = ir::ConvertSSA(z);
CHECK(ir::VerifySSA(z_ssa)); CHECK(ir::VerifySSA(z_ssa));
} }
...@@ -20,7 +20,7 @@ TEST(IRSSA, Basic) { ...@@ -20,7 +20,7 @@ TEST(IRSSA, Basic) {
using namespace Halide::Internal; using namespace Halide::Internal;
using namespace tvm; using namespace tvm;
Var x("x"), y; Var x("x"), y;
auto z = x + y; auto z = Evaluate::make(x + y);
CHECK(ir::VerifySSA(z)); CHECK(ir::VerifySSA(z));
} }
......
import tvm
def test_inline():
m = tvm.Var('m')
A = tvm.placeholder((m,), name='A')
T = tvm.compute((m,), lambda i,: A(i) + 10, name='T')
X = T(100)
stmt = tvm.make.Evaluate(T(10) + 11 * T(100))
stmt = tvm.ir_pass.Inline(
T, T.source_op.iter_var, T.source_op.body, stmt)
print(stmt)
assert(tvm.ir_pass.VerifySSA(stmt))
if __name__ == "__main__":
test_inline()
import tvm
def test_verify_ssa():
x = tvm.Var('x')
y = tvm.Var()
z = tvm.make.Evaluate(x + y)
assert(tvm.ir_pass.VerifySSA(z))
def test_convert_ssa():
x = tvm.Var('x')
y = tvm.Var()
let = tvm.make.Let(x, 1, x + 1)
z = tvm.make.Evaluate(let + let)
assert(not tvm.ir_pass.VerifySSA(z))
z_ssa = tvm.ir_pass.ConvertSSA(z)
assert(tvm.ir_pass.VerifySSA(z_ssa))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment