Commit 9ba40dc0 by Tianqi Chen Committed by GitHub

[CODEGEN/PASS] Improve callpacked lowering, allow pass array callback. (#110)

* [CODEGEN/PASS] Improve callpacked lowering, allow pass array callback.

* fix cython
parent d45b6d4b
...@@ -46,6 +46,7 @@ tvm.ir_pass ...@@ -46,6 +46,7 @@ tvm.ir_pass
tvm.ir_pass.VectorizeLoop tvm.ir_pass.VectorizeLoop
tvm.ir_pass.UnrollLoop tvm.ir_pass.UnrollLoop
tvm.ir_pass.StorageSync tvm.ir_pass.StorageSync
tvm.ir_pass.StorageRewrite
tvm.ir_pass.MakeAPI tvm.ir_pass.MakeAPI
tvm.ir_pass.SplitHostDevice tvm.ir_pass.SplitHostDevice
tvm.ir_pass.InjectVirtualThread tvm.ir_pass.InjectVirtualThread
...@@ -53,6 +54,8 @@ tvm.ir_pass ...@@ -53,6 +54,8 @@ tvm.ir_pass
tvm.ir_pass.RemoveNoOp tvm.ir_pass.RemoveNoOp
tvm.ir_pass.SplitPipeline tvm.ir_pass.SplitPipeline
tvm.ir_pass.LowerThreadAllreduce tvm.ir_pass.LowerThreadAllreduce
tvm.ir_pass.LowerIntrin
tvm.ir_pass.LowerPackedCall
tvm.ir_pass.NarrowChannelAccess tvm.ir_pass.NarrowChannelAccess
.. automodule:: tvm.ir_pass .. automodule:: tvm.ir_pass
......
...@@ -245,7 +245,6 @@ Expr max(Expr source, Array<IterVar> axis); ...@@ -245,7 +245,6 @@ Expr max(Expr source, Array<IterVar> axis);
*/ */
Expr min(Expr source, Array<IterVar> axis); Expr min(Expr source, Array<IterVar> axis);
// print functions for expr // print functions for expr
std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*) std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*)
......
...@@ -140,6 +140,10 @@ constexpr const char* volatile_scope = "volatile_scope"; ...@@ -140,6 +140,10 @@ constexpr const char* volatile_scope = "volatile_scope";
constexpr const char* storage_scope = "storage_scope"; constexpr const char* storage_scope = "storage_scope";
/*! \brief Mark storage scope of realization */ /*! \brief Mark storage scope of realization */
constexpr const char* realize_scope = "realize_scope"; constexpr const char* realize_scope = "realize_scope";
/*! \brief The allocation context for global malloc in host. */
constexpr const char* device_context_id = "device_context_id";
/*! \brief The device type. */
constexpr const char* device_context_type = "device_context_type";
/*! \brief Mark of loop scope */ /*! \brief Mark of loop scope */
constexpr const char* loop_scope = "loop_scope"; constexpr const char* loop_scope = "loop_scope";
/*! \brief Mark of reduce scope */ /*! \brief Mark of reduce scope */
...@@ -167,25 +171,24 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope"; ...@@ -167,25 +171,24 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
/*! \brief namespace of TVM Intrinsic functions */ /*! \brief namespace of TVM Intrinsic functions */
namespace intrinsic { namespace intrinsic {
// Most of the intrinsics is to enab
/*! /*!
* \brief See pesudo code * \brief See pesudo code
* *
* Type tvm_api_load_arg(TVMArg* args, int* args_type_id, i) { * Type tvm_struct_get(StructType* arr, int index, int field_id) {
* assert(arg_type_id[i] == typeid(Type)); * return arr[index]->field;
* return args[i];
* } * }
* \sa TVMStructFieldKind
*/ */
constexpr const char* tvm_api_load_arg = "tvm_api_load_arg"; constexpr const char* tvm_struct_get = "tvm_struct_get";
/*! /*!
* \brief See pesudo code * \brief See pesudo code
* *
* Type tvm_array_get_field(TVMArray* arr, int field_id) { * Handle tvm_struct_set(StructType* arr, int index, int field_id, value) {
* return arr->field; * arr[index]->field = value;
* } * }
* \sa TVMArrayFieldKind * \sa TVMStructFieldKind
*/ */
constexpr const char* tvm_array_get_field = "tvm_array_get_field"; constexpr const char* tvm_struct_set = "tvm_struct_set";
/*! /*!
* \brief See pesudo code * \brief See pesudo code
* *
...@@ -197,6 +200,48 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; ...@@ -197,6 +200,48 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*! /*!
* \brief See pesudo code * \brief See pesudo code
* *
* dtype in {shape, array, arg_value, arg_tcode}
*
* Handle tvm_stack_alloca(string dtype, int num) {
* return new on stack dtype[num];
* }
* \sa TVMStructFieldKind
*/
constexpr const char* tvm_stack_alloca = "tvm_stack_alloca";
/*!
* \brief Allocate a shape tuple on stack, return the handle.
*
* Handle tvm_stack_make_shape(list args) {
* ret = alloca stack int64_t[len(args)];
* for i in range(len(args)):
* ret[i] = args[i]
* return &ret[0];
* }
*/
constexpr const char* tvm_stack_make_shape = "tvm_stack_make_shape";
/*!
* \brief Allocate a NDArray(DLTensor) on stack, return the handle.
*
* Type tvm_stack_make_array(Expr data,
* Expr shape,
* Expr strides,
* Expr ndim,
* Expr dtype,
* Expr byte_offset) {
* ret = alloca stack DLTensor();
* ret->data = data;
* ret->shape = shape;
* ret->strides = strides != 0 ? strides : nullptr;
* ret->ndim = ndim;
* ret->dtype = dtype.type();
* ret->byte_offset = byte_offset;
* return ret;
* }
*/
constexpr const char* tvm_stack_make_array = "tvm_stack_make_array";
/*!
* \brief See pesudo code
*
* int tvm_call_packed(name, TVMValue* args) { * int tvm_call_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv(); * ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name); * const PackedFunc* f = env->GetFuncFromEnv(name);
...@@ -206,6 +251,23 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null"; ...@@ -206,6 +251,23 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
*/ */
constexpr const char* tvm_call_packed = "tvm_call_packed"; constexpr const char* tvm_call_packed = "tvm_call_packed";
/*! /*!
* \brief Lowered version of call packed, the space of value and
* type codes are explicitly allocated.
*
* int tvm_call_packed_lowered(name,
* TVMValue* value_stack,
* int* tcode_stack,
* int begin,
* int end) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* f->CallPacked(TVMArgs(value_stack[begin:end],
* tcode_stack[begin:end]),
* TVMRetValue(value_stack + end, tcode_stack + end));
* }
*/
constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
/*!
* \brief See pesudo code * \brief See pesudo code
* *
* int tvm_storage_sync(std::string storage_scope) { * int tvm_storage_sync(std::string storage_scope) {
...@@ -231,16 +293,24 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit"; ...@@ -231,16 +293,24 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
*/ */
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce"; constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
/*! \brief The field id of each field in array */ /*! \brief The kind of structre field info */
enum TVMArrayFieldKind { enum TVMStructFieldKind : int {
kData = 0, // array head address
kNDim = 1, kArrAddr,
kShape = 2, kArrData,
kStrides = 3, kArrShape,
kTypeCode = 4, kArrStrides,
kTypeBits = 5, kArrNDim,
kTypeLanes = 6, kArrTypeCode,
kByteOffset = 7 kArrTypeBits,
kArrTypeLanes,
kArrByteOffset,
kArrDeviceId,
kArrDeviceType,
kArrKindBound_,
// TVMValue field
kTVMValueContent,
kTVMValueKindBound_
}; };
} // namespace intrinsic } // namespace intrinsic
......
...@@ -252,6 +252,13 @@ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope); ...@@ -252,6 +252,13 @@ LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size); LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
/*! /*!
* \brief Lower packed function call.
* \param f The function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerPackedCall(LoweredFunc f);
/*!
* \brief Lower intrinsic function calls. * \brief Lower intrinsic function calls.
* \param f The device function to be lowered. * \param f The device function to be lowered.
* \param target The target device. * \param target The target device.
......
...@@ -10,7 +10,7 @@ from numbers import Number, Integral ...@@ -10,7 +10,7 @@ from numbers import Number, Integral
from ..base import _LIB, check_call from ..base import _LIB, check_call
from ..base import c_str, string_types from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import TVMType, TVMByteArray, NDArrayBase from ..ndarray import TVMType, TVMByteArray, NDArrayBase, _make_array
from .types import TVMValue, TypeCode from .types import TVMValue, TypeCode
from .types import TVMPackedCFunc, TVMCFuncFinalizer from .types import TVMPackedCFunc, TVMCFuncFinalizer
from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func from .types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
...@@ -188,6 +188,7 @@ def _handle_return_func(x): ...@@ -188,6 +188,7 @@ def _handle_return_func(x):
handle = FunctionHandle(handle) handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False) return _CLASS_FUNCTION(handle, False)
# setup return handle for function type # setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
...@@ -195,7 +196,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( ...@@ -195,7 +196,7 @@ C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE) _handle_return_func, TypeCode.FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE) _return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True)
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
......
...@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF ...@@ -4,7 +4,7 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral from numbers import Number, Integral
from ..base import string_types from ..base import string_types
from ..node_generic import convert_to_node, NodeGeneric from ..node_generic import convert_to_node, NodeGeneric
from ..ndarray import NDArrayBase, TVMType, TVMByteArray from ..ndarray import NDArrayBase, TVMType, TVMByteArray, _make_array
print("TVM: Initializing cython mode...") print("TVM: Initializing cython mode...")
...@@ -29,7 +29,10 @@ cdef int tvm_callback(TVMValue* args, ...@@ -29,7 +29,10 @@ cdef int tvm_callback(TVMValue* args,
tcode == kFuncHandle or tcode == kFuncHandle or
tcode == kModuleHandle): tcode == kModuleHandle):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(TVMCbArgToReturn(&value, tcode))
if tcode != kArrayHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
else:
pyargs.append(_make_array(ctypes_handle(value.v_handle), True))
try: try:
rv = local_pyfunc(*pyargs) rv = local_pyfunc(*pyargs)
except Exception: except Exception:
...@@ -64,7 +67,9 @@ def convert_to_tvm_func(object pyfunc): ...@@ -64,7 +67,9 @@ def convert_to_tvm_func(object pyfunc):
<void*>(pyfunc), <void*>(pyfunc),
tvm_callback_finalize, tvm_callback_finalize,
&chandle)) &chandle))
return _CLASS_FUNCTION(ctypes_handle(chandle), False) ret = _CLASS_FUNCTION(None, False)
(<FunctionBase>ret).chandle = chandle
return ret
cdef inline void make_arg(object arg, cdef inline void make_arg(object arg,
......
...@@ -198,9 +198,9 @@ def sync(ctx): ...@@ -198,9 +198,9 @@ def sync(ctx):
class NDArrayBase(object): class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
__slots__ = ["handle"] __slots__ = ["handle", "is_view"]
# pylint: disable=no-member # pylint: disable=no-member
def __init__(self, handle): def __init__(self, handle, is_view=False):
"""Initialize the function with handle """Initialize the function with handle
Parameters Parameters
...@@ -209,8 +209,10 @@ class NDArrayBase(object): ...@@ -209,8 +209,10 @@ class NDArrayBase(object):
the handle to the underlying C++ TVMArray the handle to the underlying C++ TVMArray
""" """
self.handle = handle self.handle = handle
self.is_view = is_view
def __del__(self): def __del__(self):
if not self.is_view:
check_call(_LIB.TVMArrayFree(self.handle)) check_call(_LIB.TVMArrayFree(self.handle))
@property @property
...@@ -302,6 +304,10 @@ class NDArrayBase(object): ...@@ -302,6 +304,10 @@ class NDArrayBase(object):
raise ValueError("Unsupported target type %s" % str(type(target))) raise ValueError("Unsupported target type %s" % str(type(target)))
return target return target
def _make_array(handle, is_view):
handle = ctypes.cast(handle, TVMArrayHandle)
return _CLASS_NDARRAY(handle, is_view)
_CLASS_NDARRAY = None _CLASS_NDARRAY = None
def _set_class_ndarray(cls): def _set_class_ndarray(cls):
......
...@@ -146,7 +146,7 @@ def build(sch, ...@@ -146,7 +146,7 @@ def build(sch,
warp_size = 32 if target == "cuda" else 1 warp_size = 32 if target == "cuda" else 1
fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size) fapi = ir_pass.LowerThreadAllreduce(fapi, warp_size)
fsplits = [s for s in ir_pass.SplitHostDevice(fapi)] fsplits = [s for s in ir_pass.SplitHostDevice(fapi)]
fsplits[0] = ir_pass.LowerPackedCall(fsplits[0])
if len(fsplits) > 1: if len(fsplits) > 1:
if not target_host: if not target_host:
target_host = "llvm" if codegen.enabled("llvm") else "stackvm" target_host = "llvm" if codegen.enabled("llvm") else "stackvm"
......
"""Intrinsics and math functions in TVM.""" """Intrinsics and math functions in TVM."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .expr import Call as _Call
from . import make as _make
from ._ffi.function import register_func as _register_func from ._ffi.function import register_func as _register_func
from .api import convert from . import make as _make
from .api import convert, const
from .expr import Call as _Call
from .schedule import Buffer as _Buffer
def _pack_buffer(buf):
"""Build intrinsics that packs the buffer.
"""
assert buf.shape
shape = _make.Call("handle", "tvm_stack_make_shape", buf.shape,
_Call.Intrinsic, None, 0)
strides = _make.Call("handle", "tvm_stack_make_shape", buf.strides,
_Call.Intrinsic, None, 0) if buf.strides else 0
pack_args = [buf.data,
shape,
strides,
len(buf.shape),
const(0, dtype=buf.dtype),
buf.byte_offset]
return _make.Call("handle", "tvm_stack_make_array",
pack_args, _Call.Intrinsic, None, 0)
def call_packed(*args): def call_packed(*args):
"""Build expression by call an external packed function """Build expression by call an external packed function.
The argument to packed function can be Expr or Buffer.
The argument is the corresponding POD type when Expr is presented.
When the argument is Buffer, the corresponding PackedFunc
will recieve an TVMArrayHandle whose content is valid during the callback period.
If the PackedFunc is a python callback, then the corresponding argument is NDArray.
Parameters Parameters
---------- ----------
args : list args : list of Expr or Buffer.
Positional arguments. Positional arguments.
Returns
-------
call : Expr
The call expression.
See Also
--------
tvm.extern : Create tensor with extern function call.
""" """
call_args = [_pack_buffer(x) if isinstance(x, _Buffer) else x for x in args]
return _make.Call( return _make.Call(
"int32", "tvm_call_packed", args, _Call.Intrinsic, None, 0) "int32", "tvm_call_packed", call_args, _Call.Intrinsic, None, 0)
def call_pure_intrin(dtype, func_name, *args): def call_pure_intrin(dtype, func_name, *args):
...@@ -34,6 +69,11 @@ def call_pure_intrin(dtype, func_name, *args): ...@@ -34,6 +69,11 @@ def call_pure_intrin(dtype, func_name, *args):
args : list args : list
Positional arguments. Positional arguments.
Returns
-------
call : Expr
The call expression.
""" """
args = convert(args) args = convert(args)
return _make.Call( return _make.Call(
...@@ -53,6 +93,11 @@ def call_pure_extern(dtype, func_name, *args): ...@@ -53,6 +93,11 @@ def call_pure_extern(dtype, func_name, *args):
args : list args : list
Positional arguments. Positional arguments.
Returns
-------
call : Expr
The call expression.
""" """
return _make.Call( return _make.Call(
dtype, func_name, convert(args), _Call.PureExtern, None, 0) dtype, func_name, convert(args), _Call.PureExtern, None, 0)
......
...@@ -36,7 +36,6 @@ TVM_REGISTER_API("_const") ...@@ -36,7 +36,6 @@ TVM_REGISTER_API("_const")
} }
}); });
TVM_REGISTER_API("_str") TVM_REGISTER_API("_str")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ir::StringImm::make(args[0]); *ret = ir::StringImm::make(args[0]);
......
...@@ -76,5 +76,6 @@ REGISTER_PASS2(SplitPipeline); ...@@ -76,5 +76,6 @@ REGISTER_PASS2(SplitPipeline);
REGISTER_PASS1(NarrowChannelAccess); REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce); REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerIntrin); REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerPackedCall);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -89,8 +89,7 @@ void CodeGenC::PrintSSAAssign( ...@@ -89,8 +89,7 @@ void CodeGenC::PrintSSAAssign(
// Print a reference expression to a buffer. // Print a reference expression to a buffer.
std::string CodeGenC::GetBufferRef( std::string CodeGenC::GetBufferRef(
const Variable* buffer, Type t, const Variable* buffer, Expr index) {
Type t, Expr index) {
std::ostringstream os; std::ostringstream os;
std::string vid = GetVarID(buffer); std::string vid = GetVarID(buffer);
std::string scope; std::string scope;
...@@ -151,6 +150,58 @@ std::string CodeGenC::GetBufferRef( ...@@ -151,6 +150,58 @@ std::string CodeGenC::GetBufferRef(
return os.str(); return os.str();
} }
// Print a reference expression to a buffer.
std::string CodeGenC::GetStructRef(
Type t, const Expr& buffer, const Expr& index, int kind) {
if (kind < intrinsic::kArrKindBound_) {
std::ostringstream os;
os << "(((TVMArray*)";
this->PrintExpr(buffer, os);
os << ")";
if (kind == intrinsic::kArrAddr) {
os << " + ";
this->PrintExpr(index, os);
os << ")";
return os.str();
}
os << '[';
this->PrintExpr(index, os);
os << "].";
// other case: get fields.
switch (kind) {
case intrinsic::kArrData: os << "data"; break;
case intrinsic::kArrShape: os << "shape"; break;
case intrinsic::kArrStrides: os << "strides"; break;
case intrinsic::kArrNDim: os << "ndim"; break;
case intrinsic::kArrTypeCode: os << "dtype.code"; break;
case intrinsic::kArrTypeBits: os << "dtype.bits"; break;
case intrinsic::kArrTypeLanes: os << "dtype.lanes"; break;
case intrinsic::kArrDeviceId: os << "ctx.device_id"; break;
case intrinsic::kArrDeviceType: os << "ctx.device_type"; break;
default: LOG(FATAL) << "unknown field code";
}
os << ')';
return os.str();
} else {
CHECK_LT(kind, intrinsic::kTVMValueKindBound_);
std::ostringstream os;
os << "(((TVMValue*)";
this->PrintExpr(buffer, os);
os << ")[" << index << "].";
if (t.is_handle()) {
os << "v_handle";
} else if (t.is_float()) {
os << "v_float64";
} else if (t.is_int()) {
os << "v_int64";
} else {
LOG(FATAL) << "donot know how to handle type" << t;
}
os << ")";
return os.str();
}
}
bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const { bool CodeGenC::HandleTypeMatch(const Variable* buf_var, Type t) const {
auto it = handle_data_type_.find(buf_var); auto it = handle_data_type_.find(buf_var);
...@@ -182,15 +233,15 @@ void CodeGenC::PrintVecElemStore(const std::string& vec, ...@@ -182,15 +233,15 @@ void CodeGenC::PrintVecElemStore(const std::string& vec,
<< " = " << value << ";\n"; << " = " << value << ";\n";
} }
std::string CodeGenC::GetVecLoad(const Variable* buffer, std::string CodeGenC::GetVecLoad(
Type t, Expr base) { Type t, const Variable* buffer, Expr base) {
return GetBufferRef(buffer, t, base); return GetBufferRef(t, buffer, base);
} }
void CodeGenC::PrintVecStore(const Variable* buffer, void CodeGenC::PrintVecStore(const Variable* buffer,
Type t, Expr base, Type t, Expr base,
const std::string& value) { const std::string& value) {
std::string ref = GetBufferRef(buffer, t, base); std::string ref = GetBufferRef(t, buffer, base);
this->PrintIndent(); this->PrintIndent();
stream << ref << " = " << value << ";\n"; stream << ref << " = " << value << ";\n";
} }
...@@ -430,42 +481,11 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*) ...@@ -430,42 +481,11 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
<< " + "; << " + ";
this->PrintExpr(l->index, os); this->PrintExpr(l->index, os);
os << ')'; os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) { } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 3U);
if (!op->type.is_handle()) { os << GetStructRef(
os << '('; op->type, op->args[0], op->args[1],
this->PrintType(op->type, os); op->args[2].as<IntImm>()->value);
os << ')';
}
os << "(((TVMArg*)";
this->PrintExpr(op->args[0], os);
os << ")[" << op->args[2] << "].";
if (op->type.is_handle()) {
os << "v_handle";
} else if (op->type.is_float()) {
os << "v_double";
} else if (op->type.is_int() || op->type.is_uint()) {
os << "v_long";
} else {
LOG(FATAL) << "donot know how to handle type" << op->type;
}
os << ")";
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
os << "(((TVMArray*)";
this->PrintExpr(op->args[0], os);
os << ")->";
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: os << "data"; break;
case intrinsic::kShape: os << "shape"; break;
case intrinsic::kStrides: os << "strides"; break;
case intrinsic::kNDim: os << "ndim"; break;
case intrinsic::kTypeCode: os << "dtype.type_code"; break;
case intrinsic::kTypeBits: os << "dtype.bits"; break;
case intrinsic::kTypeLanes: os << "dtype.lanes"; break;
default: LOG(FATAL) << "unknown field code";
}
os << ')';
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U); CHECK_EQ(op->args.size(), 1U);
os << "("; os << "(";
...@@ -513,12 +533,12 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -513,12 +533,12 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes(); int lanes = op->type.lanes();
// delcare type. // delcare type.
if (op->type.lanes() == 1) { if (op->type.lanes() == 1) {
std::string ref = GetBufferRef(op->buffer_var.get(), op->type, op->index); std::string ref = GetBufferRef(op->type, op->buffer_var.get(), op->index);
os << ref; os << ref;
} else { } else {
Expr base; Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) { if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
std::string ref = GetVecLoad(op->buffer_var.get(), op->type, base); std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
os << ref; os << ref;
} else { } else {
// load seperately. // load seperately.
...@@ -552,7 +572,7 @@ void CodeGenC::VisitStmt_(const Store* op) { ...@@ -552,7 +572,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
Type t = op->value.type(); Type t = op->value.type();
if (t.lanes() == 1) { if (t.lanes() == 1) {
std::string value = this->PrintExpr(op->value); std::string value = this->PrintExpr(op->value);
std::string ref = this->GetBufferRef(op->buffer_var.get(), t, op->index); std::string ref = this->GetBufferRef(t, op->buffer_var.get(), op->index);
this->PrintIndent(); this->PrintIndent();
stream << ref << " = " << value << ";\n"; stream << ref << " = " << value << ";\n";
} else { } else {
...@@ -744,14 +764,25 @@ void CodeGenC::VisitStmt_(const Block *op) { ...@@ -744,14 +764,25 @@ void CodeGenC::VisitStmt_(const Block *op) {
void CodeGenC::VisitStmt_(const Evaluate *op) { void CodeGenC::VisitStmt_(const Evaluate *op) {
if (is_const(op->value)) return; if (is_const(op->value)) return;
const Call* call = op->value.as<Call>(); const Call* call = op->value.as<Call>();
if (call) {
if (call && call->is_intrinsic(intrinsic::tvm_storage_sync)) { if (call->is_intrinsic(intrinsic::tvm_storage_sync)) {
this->PrintStorageSync(call); this->PrintStorageSync(call); return;
} else { } else if (call->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(call->args.size(), 4);
std::string value = PrintExpr(call->args[3]);
std::string ref = GetStructRef(
call->args[3].type(),
call->args[0],
call->args[1],
call->args[2].as<IntImm>()->value);
this->PrintIndent();
this->stream << ref << " = " << value << ";\n";
return;
}
}
std::string vid = this->PrintExpr(op->value); std::string vid = this->PrintExpr(op->value);
this->PrintIndent(); this->PrintIndent();
this->stream << "(void)" << vid << ";\n"; this->stream << "(void)" << vid << ";\n";
}
} }
void CodeGenC::VisitStmt_(const ProducerConsumer *op) { void CodeGenC::VisitStmt_(const ProducerConsumer *op) {
......
...@@ -133,8 +133,7 @@ class CodeGenC : ...@@ -133,8 +133,7 @@ class CodeGenC :
const std::string&op, Type op_type, const std::string&op, Type op_type,
Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*) Expr lhs, Expr rhs, std::ostream& os); // NOLINT(*)
// print vector load // print vector load
virtual std::string GetVecLoad(const Variable* buffer, virtual std::string GetVecLoad(Type t, const Variable* buffer, Expr base);
Type t, Expr base);
// print vector store // print vector store
virtual void PrintVecStore(const Variable* buffer, virtual void PrintVecStore(const Variable* buffer,
Type t, Expr base, Type t, Expr base,
...@@ -146,11 +145,13 @@ class CodeGenC : ...@@ -146,11 +145,13 @@ class CodeGenC :
virtual void PrintVecElemStore( virtual void PrintVecElemStore(
const std::string& vec, Type t, int i, const std::string& value); const std::string& vec, Type t, int i, const std::string& value);
protected: protected:
// Print reference to struct location
std::string GetStructRef(
Type t, const Expr& buffer, const Expr& index, int kind);
// print reference to a buffer as type t in index. // print reference to a buffer as type t in index.
std::string GetBufferRef(const Variable* buffer, std::string GetBufferRef(
Type t, Expr index); Type t, const Variable* buffer, Expr index);
/*! /*!
* \brief If buffer is allocated as type t. * \brief If buffer is allocated as type t.
* \param buf_var The buffer variable. * \param buf_var The buffer variable.
......
...@@ -95,8 +95,8 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t, ...@@ -95,8 +95,8 @@ void CodeGenOpenCL::PrintVecAddr(const Variable* buffer, Type t,
os << GetVarID(buffer) << " + "; os << GetVarID(buffer) << " + ";
PrintExpr(base, os); PrintExpr(base, os);
} }
std::string CodeGenOpenCL::GetVecLoad(const Variable* buffer, std::string CodeGenOpenCL::GetVecLoad(
Type t, Expr base) { Type t, const Variable* buffer, Expr base) {
std::ostringstream os; std::ostringstream os;
os << "vload" << t.lanes() << "(0, "; os << "vload" << t.lanes() << "(0, ";
PrintVecAddr(buffer, t, base, os); PrintVecAddr(buffer, t, base, os);
......
...@@ -24,8 +24,8 @@ class CodeGenOpenCL : public CodeGenC { ...@@ -24,8 +24,8 @@ class CodeGenOpenCL : public CodeGenC {
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
void PrintStorageSync(const Call* op) final; // NOLINT(*) void PrintStorageSync(const Call* op) final; // NOLINT(*)
void PrintType(Type t, std::ostream& os) const final; // NOLINT(*) void PrintType(Type t, std::ostream& os) const final; // NOLINT(*)
std::string GetVecLoad(const Variable* buffer, std::string GetVecLoad(Type t, const Variable* buffer,
Type t, Expr base) final; Expr base) final;
void PrintVecStore(const Variable* buffer, void PrintVecStore(const Variable* buffer,
Type t, Expr base, Type t, Expr base,
const std::string& value) final; // NOLINT(*) const std::string& value) final; // NOLINT(*)
......
...@@ -190,6 +190,7 @@ class CodeGenLLVM : ...@@ -190,6 +190,7 @@ class CodeGenLLVM :
llvm::Value* CreateBroadcast(llvm::Value* value, int lanes); llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
llvm::Value* GetConstString(const std::string& str); llvm::Value* GetConstString(const std::string& str);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index); llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateStructRefPtr(Type t, llvm::Value* buffer, llvm::Value* index, int kind);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value); llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str); llvm::Value* GetPackedFuncHandle(const std::string& str);
// Vector concatenation. // Vector concatenation.
......
...@@ -37,7 +37,7 @@ void InitializeLLVM() { ...@@ -37,7 +37,7 @@ void InitializeLLVM() {
} }
std::pair<llvm::TargetMachine*, std::string> std::pair<llvm::TargetMachine*, std::string>
LLVMGetTarget(const std::string& target_str) { GetLLVMTarget(const std::string& target_str) {
// setup target triple // setup target triple
std::string target_triple; std::string target_triple;
CHECK_EQ(target_str.substr(0, 4), "llvm"); CHECK_EQ(target_str.substr(0, 4), "llvm");
......
...@@ -55,7 +55,7 @@ void InitializeLLVM(); ...@@ -55,7 +55,7 @@ void InitializeLLVM();
* \return Pair of target machine and target triple. * \return Pair of target machine and target triple.
*/ */
std::pair<llvm::TargetMachine*, std::string> std::pair<llvm::TargetMachine*, std::string>
LLVMGetTarget(const std::string& target_str); GetLLVMTarget(const std::string& target_str);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -94,7 +94,7 @@ class LLVMModuleNode : public runtime::ModuleNode { ...@@ -94,7 +94,7 @@ class LLVMModuleNode : public runtime::ModuleNode {
void Init(const Array<LoweredFunc>& funcs, std::string target) { void Init(const Array<LoweredFunc>& funcs, std::string target) {
InitializeLLVM(); InitializeLLVM();
std::tie(tm_, target_triple_) = LLVMGetTarget(target); std::tie(tm_, target_triple_) = GetLLVMTarget(target);
CHECK_NE(funcs.size(), 0U); CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>(); ctx_ = std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg; CodeGenLLVM cg;
......
...@@ -70,25 +70,6 @@ int CodeGenStackVM::AllocVarID(const Variable* v) { ...@@ -70,25 +70,6 @@ int CodeGenStackVM::AllocVarID(const Variable* v) {
return vid; return vid;
} }
void CodeGenStackVM::PushCallPacked(
int fid, const std::vector<int>& arg_type_codes) {
StackVM::Code code;
// CALL_PACKED_FUNC
code.op_code = StackVM::CALL_PACKED_FUNC;
vm_.code.push_back(code);
// num_args
code.v_int = static_cast<int>(arg_type_codes.size());
vm_.code.push_back(code);
// fid
code.v_int = fid;
vm_.code.push_back(code);
// type codes.
for (int tcode : arg_type_codes) {
code.v_int = tcode;
vm_.code.push_back(code);
}
}
int CodeGenStackVM::GetVarID(const Variable* v) const { int CodeGenStackVM::GetVarID(const Variable* v) const {
auto it = var_idmap_.find(v); auto it = var_idmap_.find(v);
CHECK(it != var_idmap_.end()) CHECK(it != var_idmap_.end())
...@@ -97,26 +78,33 @@ int CodeGenStackVM::GetVarID(const Variable* v) const { ...@@ -97,26 +78,33 @@ int CodeGenStackVM::GetVarID(const Variable* v) const {
} }
void CodeGenStackVM::VisitExpr_(const Load* op) { void CodeGenStackVM::VisitExpr_(const Load* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get())); this->Push(op->buffer_var);
if (op->type == UInt(32) && op->index.as<IntImm>()) { StackVM::OpCode code = StackVM::GetLoad(Type2TVMType(op->type));
this->PushOp(StackVM::ARRAY_LOAD_UINT32, op->index.as<IntImm>()->value); if (const IntImm* index = op->index.as<IntImm>()) {
this->PushOp(code, op->index.as<IntImm>()->value);
} else { } else {
this->Push(op->index); this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes()); this->PushOp(StackVM::PUSH_I64, op->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD); this->PushOp(StackVM::ADDR_ADD);
this->PushOp(StackVM::GetLoad(Type2TVMType(op->type))); this->PushOp(code, 0);
} }
} }
void CodeGenStackVM::VisitStmt_(const Store* op) { void CodeGenStackVM::VisitStmt_(const Store* op) {
this->PushOp(StackVM::LOAD_HEAP, GetVarID(op->buffer_var.get())); this->Push(op->buffer_var);
StackVM::OpCode code = StackVM::GetStore(Type2TVMType(op->value.type()));
if (const IntImm* index = op->index.as<IntImm>()) {
this->Push(op->value);
this->PushOp(code, op->index.as<IntImm>()->value);
} else {
this->Push(op->index); this->Push(op->index);
this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes()); this->PushOp(StackVM::PUSH_I64, op->value.type().element_of().bytes());
this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD); this->PushOp(StackVM::ADDR_ADD);
this->Push(op->value); this->Push(op->value);
this->PushOp(StackVM::GetStore(Type2TVMType(op->value.type()))); this->PushOp(code, 0);
}
} }
void CodeGenStackVM::VisitStmt_(const Allocate* op) { void CodeGenStackVM::VisitStmt_(const Allocate* op) {
...@@ -141,41 +129,29 @@ void CodeGenStackVM::VisitExpr_(const Call* op) { ...@@ -141,41 +129,29 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes()); this->PushOp(StackVM::PUSH_I64, l->type.element_of().bytes());
this->PushOp(StackVM::MUL_I64); this->PushOp(StackVM::MUL_I64);
this->PushOp(StackVM::ADDR_ADD); this->PushOp(StackVM::ADDR_ADD);
} else if (op->is_intrinsic(intrinsic::tvm_api_load_arg)) { } else if (op->is_intrinsic(Call::null_handle)) {
this->PushOp(StackVM::PUSH_I64, 0);
} else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
CHECK_EQ(op->args.size(), 3U); CHECK_EQ(op->args.size(), 3U);
int kind = op->args[2].as<IntImm>()->value;
this->Push(op->args[0]); this->Push(op->args[0]);
this->Push(op->args[1]); const IntImm* index = op->args[1].as<IntImm>();
this->Push(op->args[2]); CHECK(index != nullptr);
if (op->type.is_handle()) { StackVM::Code code;
this->PushOp(StackVM::TVM_LOAD_ARG_HANDLE); code.op_code = StackVM::TVM_STRUCT_GET;
} else if (op->type.is_float()) { vm_.code.push_back(code);
this->PushOp(StackVM::TVM_LOAD_ARG_FP64); code.v_int = index->value;
} else if (op->type.is_int() || op->type.is_uint()) { vm_.code.push_back(code);
this->PushOp(StackVM::TVM_LOAD_ARG_INT64); code.v_int = kind;
} else { vm_.code.push_back(code);
LOG(FATAL) << "donot know how to handle type" << op->type; } else if (op->is_intrinsic(intrinsic::tvm_call_packed_lowered)) {
} CHECK_GE(op->args.size(), 5U);
} else if (op->is_intrinsic(intrinsic::tvm_array_get_field)) {
CHECK_EQ(op->args.size(), 2U);
this->Push(op->args[0]);
switch (op->args[1].as<IntImm>()->value) {
case intrinsic::kData: PushOp(StackVM::TVM_ARRAY_GET_DATA); break;
case intrinsic::kShape: PushOp(StackVM::TVM_ARRAY_GET_SHAPE); break;
case intrinsic::kStrides: PushOp(StackVM::TVM_ARRAY_GET_STRIDES); break;
case intrinsic::kNDim: PushOp(StackVM::TVM_ARRAY_GET_NDIM); break;
case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
case intrinsic::kByteOffset: PushOp(StackVM::TVM_ARRAY_GET_BYTE_OFFSET); break;
default: LOG(FATAL) << "unknown field code";
}
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
CHECK_GE(op->args.size(), 1U);
const StringImm* s = op->args[0].as<StringImm>(); const StringImm* s = op->args[0].as<StringImm>();
CHECK(s != nullptr) << "tvm_call_global expect first argument as function name"; CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
for (size_t i = 1; i < op->args.size(); ++i) { this->Push(op->args[1]);
this->Push(op->args[i]); this->Push(op->args[2]);
} int begin = op->args[3].as<IntImm>()->value;
int end = op->args[4].as<IntImm>()->value;
// find the fuction id. // find the fuction id.
const std::string& func_name = s->value; const std::string& func_name = s->value;
auto it = extern_fun_idmap_.find(func_name); auto it = extern_fun_idmap_.find(func_name);
...@@ -187,16 +163,39 @@ void CodeGenStackVM::VisitExpr_(const Call* op) { ...@@ -187,16 +163,39 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
vm_.extern_func_name.push_back(func_name); vm_.extern_func_name.push_back(func_name);
extern_fun_idmap_[func_name] = fid; extern_fun_idmap_[func_name] = fid;
} }
// get the argument type code. // CALL_PACKED_FUNC
std::vector<int> arg_type_codes; StackVM::Code code;
for (size_t i = 1; i < op->args.size(); ++i) { code.op_code = StackVM::CALL_PACKED_LOWERED;
Type t = op->args[i].type(); vm_.code.push_back(code);
int code = t.code(); code.v_int = fid;
int lanes = t.lanes(); vm_.code.push_back(code);
CHECK_EQ(lanes, 1); code.v_int = begin;
arg_type_codes.push_back(code); vm_.code.push_back(code);
code.v_int = end;
vm_.code.push_back(code);
} else if (op->is_intrinsic(intrinsic::tvm_stack_alloca)) {
CHECK_EQ(op->args.size(), 2U);
const std::string& type = op->args[0].as<StringImm>()->value;
const IntImm* num = op->args[1].as<IntImm>();
CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant");
static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
size_t unit = sizeof(TVMValue);
size_t size = 0;
if (type == "shape") {
size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
} else if (type == "arg_value") {
size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
} else if (type == "arg_tcode") {
size = (num->value * sizeof(int) + unit - 1) / unit;
} else if (type == "array") {
size = (num->value * sizeof(TVMArray) + unit - 1) / unit;
} else {
LOG(FATAL) << "Unknown stack alloca type " << type;
} }
this->PushCallPacked(fid, arg_type_codes); // add stack size to be safe.
vm_.stack_size += size;
this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
} else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) { } else if (op->is_intrinsic(intrinsic::tvm_handle_is_null)) {
CHECK_EQ(op->args.size(), 1U); CHECK_EQ(op->args.size(), 1U);
this->Push(op->args[0]); this->Push(op->args[0]);
...@@ -389,10 +388,26 @@ void CodeGenStackVM::VisitStmt_(const Block *op) { ...@@ -389,10 +388,26 @@ void CodeGenStackVM::VisitStmt_(const Block *op) {
if (op->rest.defined()) this->Push(op->rest); if (op->rest.defined()) this->Push(op->rest);
} }
void CodeGenStackVM::VisitStmt_(const Evaluate *op) { void CodeGenStackVM::VisitStmt_(const Evaluate *ev) {
if (is_const(op->value)) return; if (is_const(ev->value)) return;
this->Push(op->value); const Call* op = ev->value.as<Call>();
if (op && op->is_intrinsic(intrinsic::tvm_struct_set)) {
CHECK_EQ(op->args.size(), 4U);
this->Push(op->args[0]);
this->Push(op->args[3]);
const IntImm* index = op->args[1].as<IntImm>();
CHECK(index != nullptr);
StackVM::Code code;
code.op_code = StackVM::TVM_STRUCT_SET;
vm_.code.push_back(code);
code.v_int = index->value;
vm_.code.push_back(code);
code.v_int = op->args[2].as<IntImm>()->value;
vm_.code.push_back(code);
} else {
this->Push(ev->value);
this->PushOp(StackVM::POP); this->PushOp(StackVM::POP);
}
} }
void CodeGenStackVM::VisitStmt_(const IfThenElse *op) { void CodeGenStackVM::VisitStmt_(const IfThenElse *op) {
......
...@@ -56,13 +56,6 @@ class CodeGenStackVM ...@@ -56,13 +56,6 @@ class CodeGenStackVM
*/ */
int64_t PushOp(StackVM::OpCode opcode, int operand); int64_t PushOp(StackVM::OpCode opcode, int operand);
/*! /*!
* \brief Push a call packed function.
* \param fid The function id.
* \param arg_type_codes The type codes of arguments.
*/
void PushCallPacked(int fid,
const std::vector<int>& arg_type_codes);
/*!
* \brief Set the relative jump offset to be offset. * \brief Set the relative jump offset to be offset.
* \param operand_index The indexed returned by PushOp. * \param operand_index The indexed returned by PushOp.
* \param operand The operand to be set. * \param operand The operand to be set.
......
...@@ -55,24 +55,33 @@ class StackVM { ...@@ -55,24 +55,33 @@ class StackVM {
EQ_F64, EQ_F64,
LT_F64, LT_F64,
LE_F64, LE_F64,
// load operation
ADDR_LOAD_UINT32,
ADDR_LOAD_INT32,
ADDR_LOAD_INT64,
ADDR_LOAD_FP64,
ADDR_LOAD_HANDLE,
// store operations
// *(stack[sp - 1].v_andle) = stack[sp].v_int64
// sp = sp - 2;
ADDR_STORE_INT64,
/*! /*!
* \brief Quick routine to load uint32 from constant offset. * \brief Routine to load data from address with const offset.
* \code * \code
* stack[sp].v_int64 = ((uint32_t*)stack[sp].v_handle)[code[pc + 1].v_int]; * stack[sp].v_int64 = ((DType*)stack[sp].v_handle)[code[pc + 1].v_int];
* pc = pc + 2; * pc = pc + 2;
* \endcode * \endcode
*/ */
ARRAY_LOAD_UINT32, ARRAY_LOAD_UINT32,
ARRAY_LOAD_INT32,
ARRAY_LOAD_INT64,
ARRAY_LOAD_FP64,
ARRAY_LOAD_HANDLE,
ARRAY_LOAD_TVMVALUE,
/*!
* \brief Routine to store data from constant offset.
* \code
* ((DType*)stack[sp - 1].v_handle)[code[pc + 1].v_int] = stack[sp];
* pc = pc + 2;
* sp = sp - 2;
* \endcode
*/
ARRAY_STORE_UINT32,
ARRAY_STORE_INT32,
ARRAY_STORE_INT64,
ARRAY_STORE_FP64,
ARRAY_STORE_HANDLE,
ARRAY_STORE_TVMVALUE,
// logical ops // logical ops
NOT, NOT,
/*! /*!
...@@ -129,20 +138,6 @@ class StackVM { ...@@ -129,20 +138,6 @@ class StackVM {
*/ */
SELECT, SELECT,
/*! /*!
* \brief call an extern packed function
* \code
* num_args = stack[sp].v_int64;
* call_fid = code[pc + 1].v_int;
* f = extern_func[call_fid];
* int* type_codes = &(code[pc + 2].v_int)
* stack[sp - num_args] = f(&stack[sp - num_args], type_codes, num_args);
* sp = sp - num_args;
* // The type codes are hidden in the code space.
* pc = pc + 2 + num_args
* \endcode
*/
CALL_PACKED_FUNC,
/*!
* \brief Assert condition is true. * \brief Assert condition is true.
* \code * \code
* CHECK(stack[sp]) << str_data[code[pc + 1].v_int]; * CHECK(stack[sp]) << str_data[code[pc + 1].v_int];
...@@ -189,18 +184,56 @@ class StackVM { ...@@ -189,18 +184,56 @@ class StackVM {
* \code * \code
*/ */
ASSERT_SP, ASSERT_SP,
// Intrinsics for API function, /*!
TVM_LOAD_ARG_INT64, * \brief call an extern packed function
TVM_LOAD_ARG_FP64, * \code
TVM_LOAD_ARG_HANDLE, * value_stack = stack[sp - 1].v_handle;
TVM_ARRAY_GET_DATA, * type_stack = stack[sp - 0].v_handle;
TVM_ARRAY_GET_SHAPE, * call_fid = code[pc + 1].v_int;
TVM_ARRAY_GET_STRIDES, * begin = code[pc + 2].v_int;
TVM_ARRAY_GET_NDIM, * end = code[pc + 3].v_int;
TVM_ARRAY_GET_TYPE_CODE, * num_args = end - begin - 1;
TVM_ARRAY_GET_TYPE_BITS, * f = extern_func[call_fid];
TVM_ARRAY_GET_TYPE_LANES, * stack[sp - 1] = f(&value_stack[begin:end-1], type_stack[begin:end-1], num_args);
TVM_ARRAY_GET_BYTE_OFFSET * sp = sp - 1;
* // The type codes are hidden in the code space.
* pc = pc + 4
* \endcode
*/
CALL_PACKED_LOWERED,
// Allocate things on stack
/*!
* \brief allocate data from stack.
* \code
* num = code[pc + 1].v_int;
* void* addr = &stack[sp];
* sp = sp + num;
* stack[sp].v_handle = addr;
* pc = pc + 1;
* \endcode
*/
TVM_STACK_ALLOCA_BY_8BYTE,
/*!
* \brief get data from structure.
* \code
* index = code[pc + 1].v_int;
* field = code[pc + 2].v_int;
* stack[sp] = ((StructType*)stack[sp].v_handle)[index]->field;
* pc = pc + 3
* \endcode
*/
TVM_STRUCT_GET,
/*!
* \brief set data into structure.
* \code
* index = code[pc + 1].v_int;
* field = code[pc + 2].v_int;
* ((StructType*)stack[sp - 1].v_handle)[index]->field = stack[sp];
* pc = pc + 3
* sp = sp - 1
* \endcode
*/
TVM_STRUCT_SET
}; };
/*! \brief The code structure */ /*! \brief The code structure */
union Code { union Code {
...@@ -276,23 +309,23 @@ class StackVM { ...@@ -276,23 +309,23 @@ class StackVM {
*/ */
static OpCode GetLoad(TVMType t) { static OpCode GetLoad(TVMType t) {
CHECK_EQ(t.lanes, 1U); CHECK_EQ(t.lanes, 1U);
if (t.code == kHandle) return ADDR_LOAD_HANDLE; if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
if (t.code == kInt) { if (t.code == kInt) {
switch (t.bits) { switch (t.bits) {
case 32 : return ADDR_LOAD_INT32; case 32 : return ARRAY_LOAD_INT32;
case 64 : return ADDR_LOAD_INT64; case 64 : return ARRAY_LOAD_INT64;
} }
} else if (t.code == kUInt) { } else if (t.code == kUInt) {
switch (t.bits) { switch (t.bits) {
case 32 : return ADDR_LOAD_UINT32; case 32 : return ARRAY_LOAD_UINT32;
} }
} else if (t.code == kFloat) { } else if (t.code == kFloat) {
switch (t.bits) { switch (t.bits) {
case 64 : return ADDR_LOAD_FP64; case 64 : return ARRAY_LOAD_FP64;
} }
} }
LOG(FATAL) << "Cannot load type " << t; LOG(FATAL) << "Cannot load type " << t;
return ADDR_LOAD_FP64; return ARRAY_LOAD_FP64;
} }
/*! /*!
* \brief Get store opcode for type t * \brief Get store opcode for type t
...@@ -301,13 +334,23 @@ class StackVM { ...@@ -301,13 +334,23 @@ class StackVM {
*/ */
static OpCode GetStore(TVMType t) { static OpCode GetStore(TVMType t) {
CHECK_EQ(t.lanes, 1U); CHECK_EQ(t.lanes, 1U);
if (t.code == kHandle) return ARRAY_STORE_HANDLE;
if (t.code == kInt) { if (t.code == kInt) {
switch (t.bits) { switch (t.bits) {
case 64 : return ADDR_STORE_INT64; case 32 : return ARRAY_STORE_INT32;
case 64 : return ARRAY_STORE_INT64;
}
} else if (t.code == kUInt) {
switch (t.bits) {
case 32 : return ARRAY_STORE_UINT32;
}
} else if (t.code == kFloat) {
switch (t.bits) {
case 64 : return ARRAY_STORE_FP64;
} }
} }
LOG(FATAL) << "Cannot store type " << t; LOG(FATAL) << "Cannot store type " << t;
return ADDR_LOAD_FP64; return ARRAY_STORE_FP64;
} }
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*) friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
......
...@@ -85,6 +85,70 @@ inline Stmt MergeSeq(const std::vector<Stmt>& seq) { ...@@ -85,6 +85,70 @@ inline Stmt MergeSeq(const std::vector<Stmt>& seq) {
return body; return body;
} }
/*!
 * \brief Build an expression that reads one field out of a TVM struct.
 * \param dtype The data type of the loaded field.
 * \param handle The struct handle.
 * \param index The offset index into the struct array.
 * \param kind The field kind to read.
 * \return The tvm_struct_get intrinsic call expression.
 */
inline Expr TVMStructGet(
    Type dtype, Var handle, int index,
    intrinsic::TVMStructFieldKind kind) {
  Expr struct_index = make_const(Int(32), index);
  Expr field_kind = make_const(Int(32), kind);
  Array<Expr> args = {handle, struct_index, field_kind};
  return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic);
}
/*!
 * \brief Compute the address of handle[offset] for the given element type.
 * \param handle The array handle.
 * \param dtype The element data type used to scale the offset.
 * \param offset The constant element offset.
 * \return A handle-typed expression for the address of the element.
 */
inline Expr AddressOffset(Var handle, Type dtype, int offset) {
  Expr element = Load::make(dtype, handle, make_const(Int(32), offset));
  return Call::make(Handle(), Call::address_of, {element}, Call::PureIntrinsic);
}
/*!
 * \brief Build a statement that stores a value into a TVM struct field.
 * \param handle The struct handle.
 * \param index The offset index into the struct array.
 * \param kind The field kind to write.
 * \param value The value to be stored.
 * \return An Evaluate statement wrapping the tvm_struct_set intrinsic.
 */
inline Stmt TVMStructSet(
    Var handle, int index,
    intrinsic::TVMStructFieldKind kind, Expr value) {
  Expr struct_index = make_const(Int(32), index);
  Expr field_kind = make_const(Int(32), kind);
  Array<Expr> args = {handle, struct_index, field_kind, value};
  Expr call = Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic);
  return Evaluate::make(call);
}
/*!
 * \brief Get the type used to pass a value through the TVM PackedFunc API.
 *
 * Handles pass through unchanged; scalar integers widen to Int(64) and
 * scalar floats widen to Float(64). Vector types are rejected.
 *
 * \param t The original type.
 * \return The corresponding API type.
 */
inline Type APIType(Type t) {
  if (t.is_handle()) {
    return t;
  }
  CHECK_EQ(t.lanes(), 1)
      << "Cannot pass vector type through packed API.";
  const bool integral = t.is_uint() || t.is_int();
  if (integral) {
    return Int(64);
  }
  CHECK(t.is_float());
  return Float(64);
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_ #endif // TVM_PASS_IR_UTIL_H_
/*!
* Copyright (c) 2017 by Contributors
* Lower calls to packed function.
* \file lower_packed_call.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <limits>
#include <string>
#include <unordered_set>
#include <vector>
#include "./ir_util.h"
namespace tvm {
namespace ir {
/*!
 * \brief Make a constant Int(32) expression from an unsigned index.
 * \param index The index value; must fit in a signed 32-bit int.
 * \return The constant Int(32) expression.
 */
inline Expr ConstInt32(size_t index) {
  // Cast the limit to size_t so the CHECK does not mix signed and
  // unsigned operands in the comparison.
  CHECK_LE(index, static_cast<size_t>(std::numeric_limits<int>::max()));
  return make_const(Int(32), static_cast<int>(index));
}
/*!
 * \brief Build a tvm_stack_alloca intrinsic call that reserves `num`
 *        entries of the given stack storage kind.
 * \param type The storage kind string understood by the codegen backend
 *             (e.g. "shape", "array", "arg_value", "arg_tcode").
 * \param num The number of entries to allocate.
 * \return The tvm_stack_alloca intrinsic call expression.
 */
inline Expr StackAlloca(const std::string& type, size_t num) {
  // Pass the string by const reference to avoid a needless copy;
  // StringImm::make copies the contents it needs.
  Array<Expr> args = {StringImm::make(type), ConstInt32(num)};
  return Call::make(Handle(), intrinsic::tvm_stack_alloca, args, Call::Intrinsic);
}
// Calculate the statistics of packed function.
// These information are needed during codegen.
// Lowers tvm_call_packed / tvm_stack_make_shape / tvm_stack_make_array
// intrinsics into lowered forms that read and write shared per-function
// stacks (value stack, type-code stack, shape stack, array stack).
// It tracks the running and maximum depth of each stack so the stacks can
// be allocated once, at their peak size, around the whole body.
class PackedCallBuilder : public IRMutator {
 public:
  // Entry point: mutate `stmt`, then wrap it in LetStmts that allocate
  // each stack at the maximum depth observed during mutation.
  // Stacks that were never used (max depth 0) are not allocated.
  Stmt Build(Stmt stmt) {
    stack_shape_ = Var("stack_shape", Handle());
    stack_array_ = Var("stack_array", Handle());
    stack_value_ = Var("stack_value", Handle());
    stack_tcode_ = Var("stack_tcode", Handle());
    stmt = this->Mutate(stmt);
    if (max_shape_stack_ != 0) {
      stmt = LetStmt::make(
          stack_shape_, StackAlloca("shape", max_shape_stack_), stmt);
    }
    if (max_array_stack_ != 0) {
      stmt = LetStmt::make(
          stack_array_, StackAlloca("array", max_array_stack_), stmt);
    }
    if (max_arg_stack_ != 0) {
      // Value and type-code stacks always share the same depth: one slot
      // of each per lowered call argument.
      stmt = LetStmt::make(
          stack_value_, StackAlloca("arg_value", max_arg_stack_), stmt);
      stmt = LetStmt::make(
          stack_tcode_, StackAlloca("arg_tcode", max_arg_stack_), stmt);
    }
    return stmt;
  }
  // After each statement is mutated, all running stack depths must be
  // back to zero (every call restored what it pushed), and any pending
  // preparation statements are prepended in front of the statement.
  // Popping from the back while wrapping keeps the earliest prep
  // statement outermost, preserving emission order at runtime.
  Stmt Mutate(Stmt stmt) final {
    stmt = IRMutator::Mutate(stmt);
    CHECK_EQ(run_shape_stack_, 0);
    CHECK_EQ(run_array_stack_, 0);
    CHECK_EQ(run_arg_stack_, 0);
    while (prep_seq_.size() != 0) {
      stmt = Block::make(prep_seq_.back(), stmt);
      prep_seq_.pop_back();
    }
    return stmt;
  }
  // Capture the device context attributes for the enclosing scope.
  // The CHECKs enforce that each attribute appears at most once in the
  // function body; the attribute nodes themselves are stripped.
  Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
    if (op->attr_key == attr::device_context_id) {
      CHECK(!device_id_.defined());
      device_id_ = op->value;
      return Mutate(op->body);
    } else if (op->attr_key == attr::device_context_type) {
      CHECK(!device_type_.defined());
      device_type_ = op->value;
      return Mutate(op->body);
    } else {
      return IRMutator::Mutate_(op, s);
    }
  }
  // Dispatch on the three intrinsics this pass lowers.
  Expr Mutate_(const Call* op, const Expr &e) final {
    if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
      return MakeCallPacked(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
      return MakeShape(op, e);
    } else if (op->is_intrinsic(intrinsic::tvm_stack_make_array)) {
      return MakeArray(op, e);
    } else {
      return IRMutator::Mutate_(op, e);
    }
  }
  // Cast e to type t, inserting a Cast node only when needed.
  Expr Convert(Type t, Expr e) {
    if (e.type() != t) {
      return Cast::make(t, e);
    } else {
      return e;
    }
  }
  // Lower tvm_stack_make_shape: store each dimension as int64 into the
  // shape stack and return the address of the first stored slot.
  // The running depth is bumped before recursing so nested shape
  // constructions (if any) land in the slots above this one.
  Expr MakeShape(const Call* op, const Expr& e) {
    size_t stack_begin = run_shape_stack_;
    run_shape_stack_ += op->args.size();
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    for (size_t i = 0; i < op->args.size(); ++i) {
      prep_seq_.emplace_back(
          Store::make(stack_shape_, Convert(Int(64), op->args[i]),
                      ConstInt32(stack_begin + i)));
    }
    return AddressOffset(stack_shape_, Int(64), stack_begin);
  }
  // Lower tvm_stack_make_array: fill one slot of the array stack with
  // the TVMArray fields taken from the intrinsic arguments
  // (data, shape, strides, ndim, dtype, byte_offset), plus the device
  // context captured from the enclosing AttrStmts, then return the
  // address of that slot.
  Expr MakeArray(const Call* op, const Expr& e) {
    size_t idx = run_array_stack_;
    run_array_stack_ += 1;
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrData, op->args[0]));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrShape, op->args[1]));
    Expr strides = op->args[2];
    // An undefined or zero strides argument means "compact layout";
    // pass a null handle so the runtime treats strides as absent.
    if (!strides.defined() || is_zero(strides)) {
      strides = make_zero(Handle());
    }
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrStrides, strides));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrNDim, op->args[3]));
    // args[4] is only used for its type: it carries the array dtype.
    Type dtype = op->args[4].type();
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeCode,
                     make_const(UInt(8), static_cast<int>(dtype.code()))));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeBits,
                     make_const(UInt(8), dtype.bits())));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrTypeLanes,
                     make_const(UInt(16), dtype.lanes())));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
                     Convert(Int(64), op->args[5])));
    CHECK(device_type_.defined()) << "Unknown device type in current IR";
    CHECK(device_id_.defined()) << "Unknown device id in current IR";
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
                     Convert(Int(32), device_id_)));
    prep_seq_.emplace_back(
        TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
                     Convert(Int(32), device_type_)));
    return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
  }
  // Lower tvm_call_packed into tvm_call_packed_lowered:
  // args[0] is the function name; each remaining argument is widened to
  // its API type, stored into the value stack, and its type code stored
  // into the tcode stack. The lowered call then references the stacks
  // plus the [begin, end) slot range.
  Expr MakeCallPacked(const Call* op, const Expr& e) {
    size_t restore_shape_stack = run_shape_stack_;
    size_t restore_array_stack = run_array_stack_;
    size_t arg_stack_begin = run_arg_stack_;
    run_arg_stack_ += op->args.size();
    // Specially handle the buffer packed intrinsic
    Expr expr = IRMutator::Mutate_(op, e);
    op = expr.as<Call>();
    for (size_t i = 1; i < op->args.size(); ++i) {
      Expr stack_index = ConstInt32(arg_stack_begin + i - 1);
      Expr arg = op->args[i];
      Type t = arg.type();
      Type api_type = APIType(t);
      if (t != api_type) {
        arg = Cast::make(api_type, arg);
      }
      prep_seq_.emplace_back(TVMStructSet(
          stack_value_, static_cast<int>(arg_stack_begin + i - 1),
          intrinsic::kTVMValueContent, arg));
      int arg_tcode = api_type.code();
      // Arguments produced by tvm_stack_make_array are tagged as array
      // handles rather than plain handles.
      if (IsArrayHandle(arg)) arg_tcode = kArrayHandle;
      prep_seq_.emplace_back(
          Store::make(stack_tcode_,
                      ConstInt32(arg_tcode),
                      stack_index));
    }
    // UPDATE stack value
    // Record the peak depths, then restore the running depths: the slots
    // used by this call (and any shapes/arrays built for it) are free for
    // reuse by sibling calls once this call completes.
    max_arg_stack_ = std::max(run_arg_stack_, max_arg_stack_);
    max_shape_stack_ = std::max(run_shape_stack_, max_shape_stack_);
    max_array_stack_ = std::max(run_array_stack_, max_array_stack_);
    run_shape_stack_ = restore_shape_stack;
    run_array_stack_ = restore_array_stack;
    run_arg_stack_ = arg_stack_begin;
    Array<Expr> packed_args = {
      op->args[0],
      stack_value_,
      stack_tcode_,
      ConstInt32(arg_stack_begin),
      ConstInt32(arg_stack_begin + op->args.size() - 1)
    };
    return Call::make(
        Int(32), intrinsic::tvm_call_packed_lowered,
        packed_args, Call::Intrinsic);
  }

 private:
  // Returns true if `arg` is the address expression produced by
  // MakeArray (a tvm_struct_get of kArrAddr).
  bool IsArrayHandle(const Expr& arg) {
    // specially set array handle.
    if (const Call* buf = arg.as<Call>()) {
      if (buf->is_intrinsic(intrinsic::tvm_struct_get) &&
          buf->args[2].as<IntImm>()->value == intrinsic::kArrAddr) {
        return true;
      }
    }
    return false;
  }
  // The prepration sequence to be emitted before the current statement.
  std::vector<Stmt> prep_seq_;
  // Device context captured from the enclosing AttrStmts.
  Expr device_type_;
  Expr device_id_;
  // Var handle for each stack.
  Var stack_shape_;
  Var stack_array_;
  Var stack_tcode_;
  Var stack_value_;
  // The running statistics: current depth of each stack.
  uint64_t run_shape_stack_{0};
  uint64_t run_array_stack_{0};
  uint64_t run_arg_stack_{0};
  // statistics of stacks: peak depth observed, used for allocation.
  uint64_t max_shape_stack_{0};
  uint64_t max_array_stack_{0};
  uint64_t max_arg_stack_{0};
};
/*!
 * \brief Lower tvm_call_packed intrinsics in the body of f.
 * \param f The function to transform.
 * \return A new LoweredFunc with lowered packed calls; f is unchanged.
 */
LoweredFunc LowerPackedCall(LoweredFunc f) {
  // Copy the node so the original function is left untouched.
  std::shared_ptr<LoweredFuncNode> node =
      std::make_shared<LoweredFuncNode>(*f.operator->());
  node->body = PackedCallBuilder().Build(node->body);
  return LoweredFunc(node);
}
} // namespace ir
} // namespace tvm
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
*/ */
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/buffer.h> #include <tvm/buffer.h>
#include <vector> #include <vector>
...@@ -15,11 +16,8 @@ ...@@ -15,11 +16,8 @@
namespace tvm { namespace tvm {
namespace ir { namespace ir {
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMArrayFieldKind kind) { inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
return Call::make( return TVMStructGet(t, arr, 0, kind);
t, intrinsic::tvm_array_get_field,
{arr, IntImm::make(Int(32), kind)},
Call::PureIntrinsic);
} }
inline Stmt AssertNull(Var handle, std::string msg) { inline Stmt AssertNull(Var handle, std::string msg) {
...@@ -55,15 +53,25 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -55,15 +53,25 @@ LoweredFunc MakeAPI(Stmt body,
std::unordered_set<const Variable*> visited; std::unordered_set<const Variable*> visited;
// the handle data types // the handle data types
Map<Var, Expr> handle_data_type; Map<Var, Expr> handle_data_type;
// The device context
Var device_id, device_type;
// --------------------------- // ---------------------------
// local function defintiions // local function defintiions
// load i-th argument as type t // load i-th argument as type t
auto f_arg_value = [&](Type t, int i) { auto f_arg_value = [&](Type t, int i) {
Array<Expr> call_args{ Array<Expr> call_args{v_packed_args,
v_packed_args, v_packed_arg_type_ids, IntImm::make(Int(32), i)}; IntImm::make(Int(32), i),
return Call::make( IntImm::make(Int(32), intrinsic::kTVMValueContent)};
t, intrinsic::tvm_api_load_arg, call_args, // load 64 bit version
Type api_type = APIType(t);
Expr res = Call::make(
api_type, intrinsic::tvm_struct_get, call_args,
Call::PureIntrinsic); Call::PureIntrinsic);
// cast to the target version.
if (api_type != t) {
res = Cast::make(t, res);
}
return res;
}; };
// get declaration of argument i // get declaration of argument i
auto f_arg_decl = [&](int i) { auto f_arg_decl = [&](int i) {
...@@ -107,8 +115,32 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -107,8 +115,32 @@ LoweredFunc MakeAPI(Stmt body,
for (int i = 0; i < static_cast<int>(api_args.size()); ++i) { for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
Var v_arg = f_arg_decl(i); Var v_arg = f_arg_decl(i);
if (i < num_packed_args) { if (i < num_packed_args) {
// Value loads
seq_init.emplace_back(LetStmt::make( seq_init.emplace_back(LetStmt::make(
v_arg, f_arg_value(v_arg.type(), i), nop)); v_arg, f_arg_value(v_arg.type(), i), nop));
// type code checks
Var tcode(v_arg->name_hint + ".code", Int(32));
seq_init.emplace_back(LetStmt::make(
tcode, Load::make(
Int(32), v_packed_arg_type_ids, IntImm::make(Int(32), i)), nop));
Type t = v_arg.type();
if (t.is_handle()) {
std::ostringstream msg;
msg << "Expect argument " << i << " to be pointer";
seq_check.emplace_back(
AssertStmt::make(tcode == kHandle ||
tcode == kArrayHandle ||
tcode == kNull, msg.str()));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << "Expect argument " << i << " to be int";
seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str()));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << "Expect argument " << i << " to be float";
seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str()));
}
} else { } else {
args.push_back(v_arg); args.push_back(v_arg);
} }
...@@ -121,7 +153,7 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -121,7 +153,7 @@ LoweredFunc MakeAPI(Stmt body,
<< "api_args can only be Buffer or Var"; << "api_args can only be Buffer or Var";
Buffer buf(api_args[i].node_); Buffer buf(api_args[i].node_);
// dimension checks // dimension checks
Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kNDim); Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kArrNDim);
std::ostringstream ndim_err_msg; std::ostringstream ndim_err_msg;
ndim_err_msg << "arg_" << i ndim_err_msg << "arg_" << i
<< ".ndim is expected to equal " << ".ndim is expected to equal "
...@@ -135,15 +167,15 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -135,15 +167,15 @@ LoweredFunc MakeAPI(Stmt body,
Type dtype = buf->dtype; Type dtype = buf->dtype;
std::ostringstream type_err_msg; std::ostringstream type_err_msg;
type_err_msg << "arg" << i << ".dtype is expected to be " << dtype; type_err_msg << "arg" << i << ".dtype is expected to be " << dtype;
Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeCode) == Expr cond = (TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeCode) ==
UIntImm::make(UInt(8), dtype.code()) && UIntImm::make(UInt(8), dtype.code()) &&
TVMArrayGet(UInt(8), v_arg, intrinsic::kTypeBits) == TVMArrayGet(UInt(8), v_arg, intrinsic::kArrTypeBits) ==
UIntImm::make(UInt(8), dtype.bits()) && UIntImm::make(UInt(8), dtype.bits()) &&
TVMArrayGet(UInt(16), v_arg, intrinsic::kTypeLanes) == TVMArrayGet(UInt(16), v_arg, intrinsic::kArrTypeLanes) ==
UIntImm::make(UInt(16), dtype.lanes())); UIntImm::make(UInt(16), dtype.lanes()));
seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str())); seq_init.emplace_back(AssertStmt::make(cond, type_err_msg.str()));
// Data Field // Data Field
if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kData), if (f_push(buf->data, TVMArrayGet(Handle(), v_arg, intrinsic::kArrData),
v_arg->name_hint + ".data")) { v_arg->name_hint + ".data")) {
Var vptr(buf->data); Var vptr(buf->data);
handle_data_type.Set(vptr, make_const(buf->dtype, 0)); handle_data_type.Set(vptr, make_const(buf->dtype, 0));
...@@ -152,20 +184,22 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -152,20 +184,22 @@ LoweredFunc MakeAPI(Stmt body,
Var v_shape(v_arg->name_hint + ".shape", Handle()); Var v_shape(v_arg->name_hint + ".shape", Handle());
handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0)); handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make( seq_init.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop)); v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kArrShape), nop));
for (size_t k = 0; k < buf->shape.size(); ++k) { for (size_t k = 0; k < buf->shape.size(); ++k) {
std::ostringstream field_name; std::ostringstream field_name;
field_name << v_shape->name_hint << '[' << k << ']'; field_name << v_shape->name_hint << '[' << k << ']';
f_push(buf->shape[k], f_push(buf->shape[k],
cast(buf->shape[k].type(), cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_shape, IntImm::make(Int(32), k))), Load::make(tvm_shape_type, v_shape,
IntImm::make(Int(32), k))),
field_name.str()); field_name.str());
} }
// strides field // strides field
Var v_strides(v_arg->name_hint + ".strides", Handle()); Var v_strides(v_arg->name_hint + ".strides", Handle());
handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0)); handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make( seq_init.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop)); v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kArrStrides),
nop));
if (buf->strides.size() == 0) { if (buf->strides.size() == 0) {
std::ostringstream stride_err_msg; std::ostringstream stride_err_msg;
stride_err_msg << "arg_" << i << ".strides:" stride_err_msg << "arg_" << i << ".strides:"
...@@ -177,13 +211,22 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -177,13 +211,22 @@ LoweredFunc MakeAPI(Stmt body,
field_name << v_strides->name_hint << '[' << k << ']'; field_name << v_strides->name_hint << '[' << k << ']';
f_push(buf->strides[k], f_push(buf->strides[k],
cast(buf->shape[k].type(), cast(buf->shape[k].type(),
Load::make(tvm_shape_type, v_strides, IntImm::make(Int(32), k))), Load::make(tvm_shape_type, v_strides,
IntImm::make(Int(32), k))),
field_name.str()); field_name.str());
} }
} }
// Byte_offset field. // Byte_offset field.
f_push(buf->byte_offset, TVMArrayGet(UInt(64), v_arg, intrinsic::kByteOffset), f_push(buf->byte_offset,
TVMArrayGet(UInt(64), v_arg, intrinsic::kArrByteOffset),
v_arg->name_hint + ".byte_offset"); v_arg->name_hint + ".byte_offset");
// device info.
f_push(device_id,
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceId),
v_arg->name_hint + ".device_id");
f_push(device_type,
TVMArrayGet(Int(32), v_arg, intrinsic::kArrDeviceType),
v_arg->name_hint + ".device_type");
} }
} }
...@@ -192,6 +235,16 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -192,6 +235,16 @@ LoweredFunc MakeAPI(Stmt body,
n->args = args; n->args = args;
n->handle_data_type = handle_data_type; n->handle_data_type = handle_data_type;
n->is_packed_func = num_unpacked_args == 0; n->is_packed_func = num_unpacked_args == 0;
// Set device context
if (visited.count(device_id.get())) {
Expr node = StringImm::make("default");
CHECK(visited.count(device_type.get()));
seq_init.push_back(AttrStmt::make(
node, attr::device_context_id, device_id, nop));
seq_init.push_back(AttrStmt::make(
node, attr::device_context_type, device_type, nop));
}
n->body = MergeNest({seq_init, seq_check}, body); n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n); LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f->body, f->args); Array<Var> undefined = UndefinedVars(f->body, f->args);
......
...@@ -70,6 +70,17 @@ class StorageAccessPatternFinder : public IRVisitor { ...@@ -70,6 +70,17 @@ class StorageAccessPatternFinder : public IRVisitor {
linear_seq_.push_back(e); linear_seq_.push_back(e);
} }
} }
void Visit_(const Evaluate* op) final {
scope_.push_back(StmtEntry());
// visit subexpr
IRVisitor::Visit_(op);
StmtEntry e = scope_.back();
scope_.pop_back();
if (e.access.size() != 0) {
e.stmt = op;
linear_seq_.push_back(e);
}
}
void Visit_(const Load* op) final { void Visit_(const Load* op) final {
// Add write access. // Add write access.
IRVisitor::Visit_(op); IRVisitor::Visit_(op);
...@@ -86,7 +97,7 @@ class StorageAccessPatternFinder : public IRVisitor { ...@@ -86,7 +97,7 @@ class StorageAccessPatternFinder : public IRVisitor {
// Directly reference to the variable count as a read. // Directly reference to the variable count as a read.
auto it = alloc_scope_level_.find(buf); auto it = alloc_scope_level_.find(buf);
if (it != alloc_scope_level_.end()) { if (it != alloc_scope_level_.end()) {
CHECK_LT(it->second, scope_.size()); CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
scope_[it->second].access.emplace_back( scope_[it->second].access.emplace_back(
AccessEntry(buf, Expr(), kOpaque, GetScope(buf))); AccessEntry(buf, Expr(), kOpaque, GetScope(buf)));
} }
......
...@@ -25,7 +25,8 @@ def test_add_pipeline(): ...@@ -25,7 +25,8 @@ def test_add_pipeline():
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt) stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
fsplits = tvm.ir_pass.SplitHostDevice(fapi) fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
def check_target(device, host="stackvm"): def check_target(device, host="stackvm"):
if not tvm.codegen.enabled(host): if not tvm.codegen.enabled(host):
......
...@@ -34,5 +34,78 @@ def test_add_pipeline(): ...@@ -34,5 +34,78 @@ def test_add_pipeline():
check_llvm() check_llvm()
def test_pack_buffer_simple():
nn = 1024
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
def extern_generator(ins, outs):
"""Manually write the IR for the extern function, add pipeline."""
return tvm.call_packed("my_extern_array_func1", ins[0], outs[0])
C = tvm.extern(A.shape, [A], extern_generator, name='C')
s = tvm.create_schedule(C.op)
@tvm.register_func
def my_extern_array_func1(aa, bb):
aa.copyto(bb)
def check_target(target):
if not tvm.codegen.enabled(target):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], target)
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
f(a, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy())
check_target("stackvm")
check_target("llvm")
def test_pack_buffer_intermediate():
nn = 1024
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
def extern_generator(ins, outs):
"""Manually write the IR for the extern function, add pipeline."""
return tvm.call_packed("my_extern_array_func2", ins[0], outs[0])
C = tvm.extern(B.shape, [B], extern_generator, name='C')
s = tvm.create_schedule(C.op)
def check_target(target):
if not tvm.codegen.enabled(target):
return
# build and invoke the kernel.
f = tvm.build(s, [A, C], target)
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
@tvm.register_func
def my_extern_array_func2(aa, bb):
assert aa.shape == a.shape
np.testing.assert_allclose(
aa.asnumpy(), a.asnumpy() + 1)
aa.copyto(bb)
f(a, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + 1)
check_target("llvm")
if __name__ == "__main__": if __name__ == "__main__":
test_pack_buffer_simple()
test_pack_buffer_intermediate()
test_add_pipeline() test_add_pipeline()
...@@ -9,7 +9,6 @@ def run_jit(fapi, check): ...@@ -9,7 +9,6 @@ def run_jit(fapi, check):
s = f.get_source() s = f.get_source()
check(f) check(f)
def test_stack_vm_basic(): def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32')) a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func @tvm.register_func
...@@ -21,6 +20,7 @@ def test_stack_vm_basic(): ...@@ -21,6 +20,7 @@ def test_stack_vm_basic():
Ab = tvm.decl_buffer((n, ), tvm.float32) Ab = tvm.decl_buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0])) stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
run_jit(fapi, lambda f: f(a)) run_jit(fapi, lambda f: f(a))
...@@ -42,6 +42,7 @@ def test_stack_vm_loop(): ...@@ -42,6 +42,7 @@ def test_stack_vm_loop():
stmt = ib.get() stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f): def check(f):
f(a) f(a)
...@@ -64,6 +65,7 @@ def test_stack_vm_cond(): ...@@ -64,6 +65,7 @@ def test_stack_vm_cond():
stmt = ib.get() stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
def check(f): def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
......
...@@ -38,7 +38,8 @@ def test_add_pipeline(): ...@@ -38,7 +38,8 @@ def test_add_pipeline():
px, x = s[C].split(C.op.axis[0], nparts=1) px, x = s[C].split(C.op.axis[0], nparts=1)
s[C].bind(px, tvm.thread_axis("pipeline")) s[C].bind(px, tvm.thread_axis("pipeline"))
fapi = lower(s, [A, B, C], "myadd") fapi = lower(s, [A, B, C], "myadd")
fsplits = tvm.ir_pass.SplitHostDevice(fapi) fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
print(fsplits[1].body) print(fsplits[1].body)
print("------") print("------")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment