Commit 4bb3c35a by Tianqi Chen Committed by GitHub

[REFACTOR/PASS] Formalize argument bind and match util (#214)

* [REFACTOR/PASS] Formalize argument bind and match util

* grammar
parent 3c191595
......@@ -64,6 +64,14 @@ bool HasSideEffect(const Expr& e);
bool ExprUseVar(const Expr& e, const Var& v);
/*!
* \brief Whether e expression used any var in variable set..
* \param e The expression to be checked.
* \param vset The variable set.
* \return Whether e uses vset.
*/
bool ExprUseVar(const Expr& e, const std::unordered_set<const Variable*>& vset);
/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
......@@ -83,6 +91,24 @@ Stmt CanonicalSimplify(Stmt stmt);
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map);
/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
/*!
......
......@@ -7,7 +7,6 @@ from . import expr as _expr
from . import collections as _collections
from ._ffi.function import _init_api
@register_node
class Buffer(NodeBase):
"""Symbolic data buffer in TVM.
......@@ -24,16 +23,19 @@ class Buffer(NodeBase):
"""
pass
@register_node
class Split(NodeBase):
"""Split operation on axis."""
pass
@register_node
class Fuse(NodeBase):
"""Fuse operation on axis."""
pass
@register_node
class IterVar(NodeBase, _expr.ExprOp):
"""Represent iteration variable.
......
......@@ -30,6 +30,11 @@ TVM_REGISTER_API("ir_pass.Equal")
}
});
TVM_REGISTER_API("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator Expr(), args[1].operator Var());
});
TVM_REGISTER_API("ir_pass.PostOrderVisit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
......@@ -69,7 +74,6 @@ REGISTER_PASS1(CanonicalSimplify);
REGISTER_PASS4(Inline);
REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS2(ExprUseVar);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
......
......@@ -215,11 +215,7 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
if (t.is_int()) {
this->PushOp(op_int64);
} else if (t.is_uint()) {
if (t.bits() <= 32) {
this->PushOp(op_int64);
} else {
LOG(FATAL) << "Cannot handle uint64_t in StackVM";
}
this->PushOp(op_int64);
} else {
this->PushOp(StackVM::CodeI64ToF64(op_int64));
}
......
/*!
* Copyright (c) 2017 by Contributors
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
*/
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/runtime/device_api.h>
#include "./ir_util.h"
#include "./arg_binder.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
void BinderAddAssert(Expr cond,
const std::string& arg_name,
std::vector<Stmt>* asserts) {
cond = Simplify(cond);
if (is_zero(cond)) {
LOG(FATAL) << "Bind have an unmet assertion: "
<< cond << ", " << " on argument " << arg_name;
}
if (!is_one(cond)) {
std::ostringstream os;
os << "Argument " << arg_name << " has an unsatisfied constraint";
asserts->emplace_back(AssertStmt::make(cond, os.str()));
}
}
bool ArgBinder::Bind_(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_lets) {
CHECK_EQ(arg.type(), value.type());
if (const Variable* v = arg.as<Variable>()) {
auto it = def_map_->find(v);
if (it == def_map_->end()) {
Var v_arg(arg.node_);
defs_.emplace_back(v_arg);
if (with_lets) {
(*def_map_)[v] = arg;
init_nest_.emplace_back(LetStmt::make(v_arg, value, Evaluate::make(0)));
} else {
(*def_map_)[v] = value;
}
return true;
} else {
BinderAddAssert(it->second == value, arg_name, &asserts_);
}
} else {
BinderAddAssert(arg == value, arg_name, &asserts_);
}
return false;
}
void ArgBinder::Bind(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_let) {
Bind_(arg, value, arg_name, with_let);
}
void ArgBinder::BindArray(const Array<Expr>& arg,
const Array<Expr>& value,
const std::string& arg_name) {
CHECK_EQ(arg.size(), value.size())
<< "Argument " << arg_name << " array size mismatch";
for (size_t i = 0; i < arg.size(); ++i) {
std::ostringstream os;
os << arg_name << "[" << i << "]";
this->Bind(arg[i], value[i], os.str());
}
}
void ArgBinder::BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name) {
CHECK_EQ(arg->scope, value->scope)
<< "Argument " << arg_name
<< " Buffer bind scope mismatch";
this->Bind(arg->data, value->data, arg_name + ".data");
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
this->Bind(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset");
}
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}
inline Stmt AssertNull(Var handle, std::string msg) {
return AssertStmt::make(Call::make(
Bool(1), intrinsic::tvm_handle_is_null,
{handle}, Call::PureIntrinsic), msg);
}
void ArgBinder::BindDLTensor(const Buffer& buffer,
const Expr& device_type,
const Expr& device_id,
const Var& handle,
const std::string& arg_name) {
const Type tvm_shape_type = TVMShapeIndexType();
const Type tvm_ndim_type = Int(32);
const Stmt nop = Evaluate::make(0);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, handle, intrinsic::kArrNDim);
Expr a_ndim = make_const(tvm_ndim_type,
static_cast<int64_t>(buffer->shape.size()));
std::ostringstream ndim_err_msg;
ndim_err_msg << arg_name
<< ".ndim is expected to equal "
<< buffer->shape.size();
asserts_.emplace_back(AssertStmt::make(a_ndim == v_ndim, ndim_err_msg.str()));
// type checks
Type dtype = buffer->dtype;
std::ostringstream type_err_msg;
type_err_msg << arg_name << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeCode) ==
UIntImm::make(UInt(8), dtype.code()) &&
TVMArrayGet(UInt(8), handle, intrinsic::kArrTypeBits) ==
UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), handle, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes()));
asserts_.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// data field
if (Bind_(buffer->data, TVMArrayGet(Handle(), handle, intrinsic::kArrData),
arg_name + ".data", true)) {
Var vptr(buffer->data);
def_handle_dtype_.Set(vptr, make_const(buffer->dtype, 0));
// mark alignment of external bufs
init_nest_.emplace_back(AttrStmt::make(
vptr, ir::attr::storage_alignment,
IntImm::make(Int(32), runtime::kAllocAlignment), nop));
}
Var v_shape(arg_name + ".shape", Handle());
def_handle_dtype_.Set(v_shape, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), handle, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buffer->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
Bind_(buffer->shape[k],
cast(buffer->shape[k].type(),
Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k), const_true(1))),
field_name.str(), true);
}
// strides field
Var v_strides(arg_name + ".strides", Handle());
def_handle_dtype_.Set(v_strides, make_const(tvm_shape_type, 0));
init_nest_.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), handle, intrinsic::kArrStrides),
nop));
if (buffer->strides.size() == 0) {
std::ostringstream stride_err_msg;
stride_err_msg << arg_name << ".strides:"
<< " expected to be nullptr for contiguous array";
init_nest_.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
} else {
for (size_t k = 0; k < buffer->strides.size(); ++k) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
Bind_(buffer->strides[k],
cast(buffer->shape[k].type(),
Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k), const_true(1))),
field_name.str(), true);
}
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype);
int64_t const_offset;
if (arith::GetConst(buffer->elem_offset, &const_offset)) {
Bind_(make_const(UInt(64), const_offset * data_bytes),
TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true);
} else {
Bind_(buffer->elem_offset,
cast(buffer->elem_offset.type(),
(TVMArrayGet(UInt(64), handle, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))),
arg_name + ".elem_offset", true);
}
// device info.
Bind_(device_type,
TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceType),
arg_name + ".device_type", true);
Bind_(device_id,
TVMArrayGet(Int(32), handle, intrinsic::kArrDeviceId),
arg_name + ".device_id", true);
}
} // namespace ir
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file arg_binder.h
* \brief Helper utility to match and bind arguments.
*/
#ifndef TVM_PASS_ARG_BINDER_H_
#define TVM_PASS_ARG_BINDER_H_
#include <tvm/expr.h>
#include <tvm/buffer.h>
#include <string>
#include <vector>
namespace tvm {
namespace ir {
/*!
* \brief Helper utility to generate match and bind of arguments.
*
* \note There is many places in TVM IR where we need argument bindings.
*
* Consider a function f(tA(shape=var(n)), tB(shape=3), tC(shape=(n+2)).
* Here n is a undefined variable that is decided by the outside, tB imposes
* a constraint such that it can only take tensor with shape 3, tC imposes
* another constraint that it's shape must equals n + 2.
* So if we call it with f(bufferA, bufferB, bufferC), we need to generate
* the following binding sequence:
* - define n = bufferA.shape[0]
* - assert bufferB.shape[0] == 3
* - assert bufferB.shape[1] == n + 3
*
* In general, this is a constraint solving problem. We have simplified assumption
* over the binding declaration, such that we require the variable occured in
* constraint must be declared in argument list. So it is illegal to have signature
* f(tA(shape=(n+3))) without any argument variable corresponds to n, even though
* it is already enough to derive n from the input argument.
*/
class ArgBinder {
public:
/*!
* \brief Constructor
* \param def_map A definition map that contains definition of known variables.
* ArgBinder will update this def_map when adding new definitions.
*/
explicit ArgBinder(
std::unordered_map<const Variable*, Expr>* def_map)
: def_map_(def_map) {
}
/*!
* \brief Try to bind arg to value, generate constraint if necessary.
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
* \param with_let Whether add lets during bind
*/
void Bind(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_let = false);
/*!
* \brief Bind array to array
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
*/
void BindArray(const Array<Expr>& arg,
const Array<Expr>& value,
const std::string& arg_name);
/*!
* \brief Bind symbolic buffer to another symbolic buffer
* \param arg The argument to be binded.
* \param value The target expression value
* \param arg_name argument name.
*/
void BindBuffer(const Buffer& arg,
const Buffer& value,
const std::string& arg_name);
/*!
* \brief Bind symbolic buffer to a DLTensor handle.
* \param buffer The argument buffer to be binded.
* \param device_type The device id to be binded.
* \param device_id The device id to be binded.
* \param handle The DLTensor handle.
* \param arg_name argument name.
*/
void BindDLTensor(const Buffer& buffer,
const Expr& device_type,
const Expr& device_id,
const Var& handle,
const std::string& arg_name);
/*! \return The defs generated in binding. */
const std::vector<Var>& defs() const {
return defs_;
}
/*! \return The asserts generated in binding */
const std::vector<Stmt>& asserts() const {
return asserts_;
}
/*!
* \brief Initialization nest generated
* This is only non-empty when BindDLTensor is called.
*
* \note The binder may choose to generate a let statement
* and simply put def_map to map Variable to itself,
* or update def_map to directly map to new value and not generate let statement.
*
* Let statement is usually generated when bind to DLTensor and memory load is involved.
* \return The initialization nest generated during binding.
*/
const std::vector<Stmt>& init_nest() const {
return init_nest_;
}
/*! \return Handle data type of the data */
const Map<Var, Expr>& def_handle_dtype() const {
return def_handle_dtype_;
}
private:
// Internal bind function
bool Bind_(const Expr& arg,
const Expr& value,
const std::string& arg_name,
bool with_lets);
/*! \brief The definition map, can be uses to substitute */
std::unordered_map<const Variable*, Expr>* def_map_;
/*! \brief defs generated in the current binder */
std::vector<Var> defs_;
/*! \brief Initialize nest */
std::vector<Stmt> init_nest_;
/*! \brief handle data type in the defintiions */
Map<Var, Expr> def_handle_dtype_;
/*! \brief asserts generated */
std::vector<Stmt> asserts_;
};
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_ARG_BINDER_H_
/*!
* Copyright (c) 2017 by Contributors
* \file ir_util.cc
* \brief Helper functions to construct and compose IR nodes.
*/
#include "./ir_util.h"
namespace tvm {
namespace ir {
Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
// use reverse iteration
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri;
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else if (s.as<AssertStmt>()) {
body = Block::make(s, body);
} else if (s.as<Allocate>()) {
auto n = std::make_shared<Allocate>(*s.as<Allocate>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
}
return body;
}
Stmt MergeNest(const std::vector<std::vector<Stmt> >& nest, Stmt body) {
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
body = MergeNest(*ri, body);
}
return body;
}
Stmt MergeSeq(const std::vector<Stmt>& seq) {
if (seq.size() == 0) return Evaluate::make(0);
Stmt body = seq[0];
for (size_t i = 1; i < seq.size(); ++i) {
body = Block::make(body, seq[i]);
}
return body;
}
} // namespace ir
} // namespace tvm
......@@ -11,6 +11,28 @@
namespace tvm {
namespace ir {
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body);
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
Stmt MergeNest(const std::vector<std::vector<Stmt> >& nest, Stmt body);
/*!
* \brief combine sequence of operations.
* \param seq The sequence.
* \return The combined Stmt
*/
Stmt MergeSeq(const std::vector<Stmt>& seq);
/*!
* \brief update array with an unary function
......@@ -39,79 +61,6 @@ inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
}
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
inline Stmt MergeNest(std::vector<Stmt> nest, Stmt body) {
// use reverse iteration
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
Stmt s = *ri;
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (s.as<IfThenElse>()) {
auto n = std::make_shared<IfThenElse>(*s.as<IfThenElse>());
CHECK(is_no_op(n->then_case));
CHECK(!n->else_case.defined());
n->then_case = body;
body = Stmt(n);
} else if (s.as<AssertStmt>()) {
body = Block::make(s, body);
} else if (s.as<Allocate>()) {
auto n = std::make_shared<Allocate>(*s.as<Allocate>());
CHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
}
return body;
}
/*!
* \brief combine the nest stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body body
* \return The combined Stmt
*/
inline Stmt MergeNest(std::vector<std::vector<Stmt> > nest, Stmt body) {
for (auto ri = nest.rbegin(); ri != nest.rend(); ++ri) {
body = MergeNest(*ri, body);
}
return body;
}
/*!
* \brief combine sequence of operations.
* \param seq The sequence.
* \return The combined Stmt
*/
inline Stmt MergeSeq(const std::vector<Stmt>& seq) {
if (seq.size() == 0) return Evaluate::make(0);
Stmt body = seq[0];
for (size_t i = 1; i < seq.size(); ++i) {
body = Block::make(body, seq[i]);
}
return body;
}
/*!
* \brief Get construct from struct
* \param dtype The data type.
* \param handle the struct handle.
......@@ -176,7 +125,6 @@ inline Type APIType(Type t) {
CHECK(t.is_float());
return Float(64);
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
......@@ -12,21 +12,12 @@
#include <unordered_set>
#include "./ir_util.h"
#include "./arg_binder.h"
#include "../arithmetic/compute_expr.h"
namespace tvm {
namespace ir {
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
return TVMStructGet(t, arr, 0, kind);
}
inline Stmt AssertNull(Var handle, std::string msg) {
return AssertStmt::make(Call::make(
Bool(1), intrinsic::tvm_handle_is_null,
{handle}, Call::PureIntrinsic), msg);
}
inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
return AssertStmt::make(lhs == rhs, msg);
}
......@@ -35,8 +26,6 @@ LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_unpacked_args) {
const Type tvm_shape_type = TVMShapeIndexType();
const Type tvm_ndim_type = Int(32);
const Stmt nop = Evaluate::make(0);
int num_args = static_cast<int>(api_args.size());
CHECK_LE(num_unpacked_args, num_args);
......@@ -48,14 +37,13 @@ LoweredFunc MakeAPI(Stmt body,
Var v_num_packed_args("num_args", Int(32));
// The arguments of the function.
Array<Var> args;
// The device context
Var device_type("dev_type"), device_id("dev_id");
// seq_init gives sequence of initialization
// seq_check gives sequence of later checks after iniit
std::vector<Stmt> seq_init, seq_check;
std::unordered_set<const Variable*> visited;
// the handle data types
Map<Var, Expr> handle_data_type;
// The device context
Var device_id, device_type;
std::unordered_map<const Variable*, Expr> vmap;
ArgBinder binder(&vmap);
// ---------------------------
// local function defintiions
// load i-th argument as type t
......@@ -81,25 +69,6 @@ LoweredFunc MakeAPI(Stmt body,
const Variable* v = api_args[i].as<Variable>();
return Var(os.str(), v ? v->type: Handle());
};
// Push related into assertions or variable defintion
// given the symbolic declaration and concrete value
auto f_push = [&](Expr sym, Expr value, std::string field) {
if (sym.as<Variable>()) {
// If sym is a Variable and this Variable is not yet defined
// add this to defintion.
Var v(sym.node_);
if (!visited.count(v.get())) {
seq_init.emplace_back(LetStmt::make(v, value, nop));
visited.insert(v.get());
return true;
}
}
// otherwise, assume sym is already defined, insert assertion.
std::ostringstream os;
os << "Field " << field << " has a unsatisfied constraint";
seq_check.emplace_back(MakeAssertEQ(sym, value, os.str()));
return false;
};
// ---------------------------
// start of logics
// add signiture for packed arguments.
......@@ -112,7 +81,6 @@ LoweredFunc MakeAPI(Stmt body,
seq_init.emplace_back(
MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}
for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
Var v_arg = f_arg_decl(i);
if (i < num_packed_args) {
......@@ -148,117 +116,30 @@ LoweredFunc MakeAPI(Stmt body,
}
// add checks for functions.
if (api_args[i].as<Variable>()) {
f_push(Var(api_args[i].node_), v_arg, v_arg->name_hint);
binder.Bind(Var(api_args[i].node_), v_arg, v_arg->name_hint, true);
} else {
// Buffer checks
CHECK(api_args[i].as<BufferNode>())
<< "api_args can only be Buffer or Var";
Buffer buf(api_args[i].node_);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kArrNDim);
std::ostringstream ndim_err_msg;
ndim_err_msg << "arg_" << i
<< ".ndim is expected to equal "
<< buf->shape.size();
seq_init.emplace_back(
MakeAssertEQ(v_ndim,
make_const(tvm_ndim_type,
static_cast<int64_t>(buf->shape.size())),
ndim_err_msg.str()));
// type checks
Type dtype = buf->dtype;
std::ostringstream type_err_msg;
type_err_msg << "arg" << i << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeCode) ==
UIntImm::make(UInt(8), dtype.code()) &&
TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeBits) ==
UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), v_arg, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes()));
seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// Data Field
if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kArrData),
v_arg->name_hint + ".data")) {
Var vptr(buf->data);
handle_data_type.Set(vptr, make_const(buf->dtype, 0));
// mark storage alignment of external buffer arguments.
seq_init.emplace_back(AttrStmt::make(
vptr, ir::attr::storage_alignment,
IntImm::make(Int(32), runtime::kAllocAlignment), nop));
}
// shape field
Var v_shape(v_arg->name_hint + ".shape", Handle());
handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buf->shape.size(); ++k) {
std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']';
f_push(buf->shape[k],
cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k), const_true(1))),
field_name.str());
}
// strides field
Var v_strides(v_arg->name_hint + ".strides", Handle());
handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kArrStrides),
nop));
if (buf->strides.size() == 0) {
std::ostringstream stride_err_msg;
stride_err_msg << "arg_" << i << ".strides:"
<< " expected to be nullptr for contiguous array";
seq_init.emplace_back(AssertNull(v_strides, stride_err_msg.str()));
} else {
for (size_t k = 0; k < buf->strides.size(); ++k) {
std::ostringstream field_name;
field_name << v_strides->name_hint << '[' << k << ']';
f_push(buf->strides[k],
cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k), const_true(1))),
field_name.str());
}
}
// Byte_offset field.
int data_bytes = GetVectorBytes(buf->dtype);
int64_t const_offset;
if (arith::GetConst(buf->elem_offset, &const_offset)) {
f_push(make_const(buf->elem_offset.type(), const_offset * data_bytes),
TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
v_arg->name_hint + ".byte_offset");
} else {
f_push(buf->elem_offset,
cast(buf->elem_offset.type(),
(TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset) /
make_const(UInt(64), data_bytes))),
v_arg->name_hint + ".elem_offset");
}
// device info.
f_push(device_id,
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId),
v_arg->name_hint + ".device_id");
f_push(device_type,
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceType),
v_arg->name_hint + ".device_type");
binder.BindDLTensor(
buf, device_type, device_id, v_arg, v_arg->name_hint);
}
}
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
n->name = name;
n->args = args;
n->handle_data_type = handle_data_type;
n->handle_data_type = binder.def_handle_dtype();
n->is_packed_func = num_unpacked_args == 0;
// Set device context
if (visited.count(device_id.get())) {
if (vmap.count(device_id.get())) {
Expr node = StringImm::make("default");
CHECK(visited.count(device_type.get()));
seq_init.push_back(AttrStmt::make(
CHECK(vmap.count(device_type.get()));
seq_check.push_back(AttrStmt::make(
node, attr::device_context_id, device_id, nop));
seq_init.push_back(AttrStmt::make(
seq_check.push_back(AttrStmt::make(
node, attr::device_context_type, device_type, nop));
Stmt set_device = IfThenElse::make(
device_type != kCPU, Evaluate::make(Call::make(
......@@ -267,7 +148,8 @@ LoweredFunc MakeAPI(Stmt body,
device_type, device_id}, Call::Intrinsic)));
body = Block::make(set_device, body);
}
n->body = MergeNest({seq_init, seq_check}, body);
n->body = MergeNest(
{seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f->body, f->args);
if (undefined.size() != 0) {
......
......@@ -37,67 +37,110 @@ bool HasSideEffect(const Expr& e) {
class IRSubstitue : public IRMutator {
public:
explicit IRSubstitue(
const std::unordered_map<const Variable*, Expr>& smap)
: smap_(smap) {
}
Expr Mutate_(const Variable* op, const Expr& e) final {
auto it = smap.find(op);
if (it != smap.end()) {
auto it = smap_.find(op);
if (it != smap_.end()) {
return it->second;
} else {
return e;
}
}
std::unordered_map<const Variable*, Expr> smap;
private:
const std::unordered_map<const Variable*, Expr>& smap_;
};
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
Stmt Substitute(Stmt stmt,
const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return stmt;
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first.get()] = kv.second;
return IRSubstitue(value_map).Mutate(stmt);
}
Expr Substitute(Expr expr,
const std::unordered_map<const Variable*, Expr>& value_map) {
if (value_map.size() == 0) return expr;
return IRSubstitue(value_map).Mutate(expr);
}
Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
std::unordered_map<const Variable*, Expr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
return m.Mutate(stmt);
return Substitute(stmt, vmap);
}
Expr Substitute(Expr expr, const Map<Var, Expr>& value_map) {
if (value_map.size() == 0) return expr;
IRSubstitue m;
for (auto kv : value_map) {
m.smap[kv.first.get()] = kv.second;
std::unordered_map<const Variable*, Expr> vmap;
for (const auto& kv : value_map) {
vmap[kv.first.get()] = kv.second;
}
return m.Mutate(expr);
return Substitute(expr, vmap);
}
class ExprUseVarVisitor : public IRVisitor {
class VarTouchVisitor : public IRVisitor {
public:
explicit ExprUseVarVisitor(const Variable* var)
: var_(var) {}
void Visit(const NodeRef& e) final {
if (use_var_) return;
IRVisitor::Visit(e);
}
void Visit_(const Variable* op) final {
if (op == var_) {
use_var_ = true;
}
Handle(op);
}
void Visit_(const Load* op) final {
if (op->buffer_var.get() == var_) {
use_var_ = true;
}
Handle(op->buffer_var.get());
IRVisitor::Visit_(op);
}
const Variable* var_;
virtual void Handle(const Variable* var) = 0;
bool use_var_{false};
};
class ExprUseVarVisitor : public VarTouchVisitor {
public:
explicit ExprUseVarVisitor(const Variable* var)
: var_(var) {}
void Handle(const Variable* var) final {
if (var == var_) use_var_ = true;
}
private:
const Variable* var_;
};
class ExprUseVSetVisitor : public VarTouchVisitor {
public:
explicit ExprUseVSetVisitor(
const std::unordered_set<const Variable*>& vset)
: vset_(vset) {}
void Handle(const Variable* var) final {
if (vset_.count(var)) use_var_ = true;
}
private:
const std::unordered_set<const Variable*>& vset_;
};
bool ExprUseVar(const Expr& e, const Var& v) {
ExprUseVarVisitor visitor(v.get());
visitor.Visit(e);
return visitor.use_var_;
}
bool ExprUseVar(const Expr& e,
const std::unordered_set<const Variable*>& vset) {
ExprUseVSetVisitor visitor(vset);
visitor.Visit(e);
return visitor.use_var_;
}
} // namespace ir
} // namespace tvm
......@@ -8,6 +8,8 @@
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <unordered_map>
#include "./ir_util.h"
#include "./arg_binder.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
......@@ -156,30 +158,6 @@ class StorageFlattener : public IRMutator {
}
private:
// Bind the symbol sym to value if it is a Variable
// send a sequence of asserts if it is a constant constrant.
// hint_name: used for error message
// add_keys: a list of newly binded keys
// add_asserts: a list of asserts during the bind
void BindSymbol(Expr sym,
Expr value,
std::string hint_name,
std::vector<const Variable*>* add_keys,
std::vector<Stmt>* add_asserts) {
if (const Variable* v = sym.as<Variable>()) {
auto it = var_remap_.find(v);
if (it == var_remap_.end()) {
add_keys->push_back(v);
var_remap_[v] = value;
return;
}
}
// add assertions
std::ostringstream os;
os << "BufferBind constaint fail " << hint_name;
add_asserts->emplace_back(
AssertStmt::make(sym == value, os.str()));
}
// Start bind
Stmt HandleBufferBindScope(const AttrStmt* op) {
Array<NodeRef> arr(op->node.node_);
......@@ -215,47 +193,16 @@ class StorageFlattener : public IRMutator {
} else {
slice = slice.MakeStrideView();
}
CHECK_EQ(slice->strides.size(), buffer->strides.size());
// start binding
std::vector<const Variable*> keys;
std::vector<Stmt> asserts;
BindSymbol(buffer->data, slice->data,
buffer->name + ".data",
&keys, &asserts);
for (size_t i = 0; i < buffer->shape.size(); ++i) {
std::ostringstream field_name;
field_name << buffer->name << ".shape[" << i << ']';
BindSymbol(buffer->shape[i], slice->shape[i],
field_name.str(),
&keys, &asserts);
}
for (size_t i = 0; i < buffer->strides.size(); ++i) {
std::ostringstream field_name;
field_name << buffer->name << ".strides[" << i << ']';
BindSymbol(buffer->strides[i], slice->strides[i],
field_name.str(),
&keys, &asserts);
}
BindSymbol(buffer->elem_offset, slice->elem_offset,
buffer->name + ".elem_offset",
&keys, &asserts);
CHECK_EQ(buffer->scope, slice->scope)
<< "Buffer bind scope mismatch";
ArgBinder binder(&var_remap_);
binder.BindBuffer(Buffer(arr[0].node_), slice, buffer->name);
// Apply the remaps
Stmt body = this->Mutate(op->body);
for (size_t i = 0; i < asserts.size(); ++i) {
Stmt ret = Simplify(this->Mutate(asserts[i]));
if (const AssertStmt* assert_op = ret.as<AssertStmt>()) {
if (!is_zero(assert_op->condition)) {
body = Block::make(ret, body);
} else {
LOG(FATAL) << "BindBuffer have unmet assertion: " << ret;
}
}
}
Stmt body = MergeNest(binder.asserts(), op->body);
body = MergeNest(binder.init_nest(), body);
body = this->Mutate(body);
// remove the binds
for (const Variable* op : keys) {
var_remap_.erase(op);
for (const Var& v : binder.defs()) {
var_remap_.erase(v.get());
}
return body;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment