Unverified Commit 75e936e1 by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager. (#5225)

* [REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager.

- SplitHostDevice
- ThreadSync
- BindDevice
- LowerThreadAllreduce
- Provide a temp fix for printing IRModule with PrimFunc before the formal text printer.

* Address comments, fix tests.

* Fix relay tests

* Explicit move
parent 9b274cbb
...@@ -47,19 +47,19 @@ enum class CallingConv : int { ...@@ -47,19 +47,19 @@ enum class CallingConv : int {
*/ */
kDefault = 0, kDefault = 0,
/*! /*!
* \brief PackedFunc that exposes a CPackedFunc signature.
*
* - Calling by PackedFunc calling convention.
* - Implementation: Expose a function with the CPackedFunc signature.
*/
kCPackedFunc = 1,
/*!
* \brief Device kernel launch * \brief Device kernel launch
* *
* - Call by PackedFunc calling convention. * - Call by PackedFunc calling convention.
* - Implementation: defined by device runtime(e.g. runtime/cuda) * - Implementation: defined by device runtime(e.g. runtime/cuda)
*/ */
kDeviceKernelLaunch = 2, kDeviceKernelLaunch = 2,
/*!
* \brief PackedFunc that exposes a CPackedFunc signature.
*
* - Calling by PackedFunc calling convention.
* - Implementation: Expose a function with the CPackedFunc signature.
*/
kCPackedFunc = 3,
}; };
/*! /*!
......
...@@ -324,6 +324,8 @@ class IRModule : public ObjectRef { ...@@ -324,6 +324,8 @@ class IRModule : public ObjectRef {
/*! \brief Declare the container type. */ /*! \brief Declare the container type. */
using ContainerType = IRModuleNode; using ContainerType = IRModuleNode;
// allow copy on write.
TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
}; };
/*! /*!
......
...@@ -49,6 +49,16 @@ struct ExprDeepEqual { ...@@ -49,6 +49,16 @@ struct ExprDeepEqual {
public: public:
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const; TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
}; };
/*!
* \brief Find undefined vars in the statment.
* \param stmt The function to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_ #endif // TVM_TIR_ANALYSIS_H_
...@@ -407,56 +407,6 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -407,56 +407,6 @@ LoweredFunc MakeAPI(Stmt body,
bool is_restricted); bool is_restricted);
/*! /*!
* \brief Bind the device type of host function to be device_type.
* \param func The function to be binded.
* \param device_type The device type to be binded.
* \return The binded function.
*/
LoweredFunc BindDeviceType(LoweredFunc func,
int device_type);
/*!
* \brief Find undefined vars in the statment.
* \param stmt The function to be checked.
* \param defs The vars that is defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be splitted.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param stmt The stmt to be trasnformed.
* \param storage_scope The storage scope considered.
*/
LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
/*!
* \brief Lower cross thread alleduce in the stmt.
* \param f The device function to be lowered.
* \param warp_size the size of warp where no sync is needed.
* \return Transformed function.
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
/*!
* \brief Lower warp memory in stmt.
* \param f The device function to be lowered.
* \param warp_size the size of warp where no sync is needed.
* this function will only take in effect if warp_size is bigger than one.
* \return Transformed function.
*/
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
/*!
* \brief Remap the thread axis * \brief Remap the thread axis
* *
* This can be used to get equivalent program which uses * This can be used to get equivalent program which uses
...@@ -471,26 +421,6 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size); ...@@ -471,26 +421,6 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map); LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
/*! /*!
* \brief Lower packed function call.
* \param f The function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerTVMBuiltin(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemeneted in storage_rewrite.cc
* \param f The function to be trasnformed
* \return Transformed function.
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments, * \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use * as well as Alloc internal to the function to use
* the most frequently accessed type for load/store * the most frequently accessed type for load/store
...@@ -514,14 +444,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f); ...@@ -514,14 +444,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f);
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
/*! /*!
* \brief Infer the TensorCore fragment infomation using tensor intrinsics
*
* \param f The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc InferFragment(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type. * \brief Verify if memory accesses are legal for a specific target device type.
* *
* In the case that tgt is cuda, if not all workload is bound with * In the case that tgt is cuda, if not all workload is bound with
......
...@@ -59,6 +59,21 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< ...@@ -59,6 +59,21 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required); const tvm::Array<tvm::PrimExpr>& required);
/*! /*!
* \brief Bind the device type ofthe function to be
* the device_type specified in the target attribute.
*
* \return The pass.
*/
TVM_DLL Pass BindDeviceType();
/*!
* \brief Split the function into a host function and device functions.
*
* \return The pass.
*/
TVM_DLL Pass SplitHostDevice();
/*!
* \brief skip assert stmt. * \brief skip assert stmt.
* *
* \return The pass. * \return The pass.
......
...@@ -14,6 +14,8 @@ ...@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=invalid-name
"""The build utils in python. """The build utils in python.
This module provides the functions to transform schedule to This module provides the functions to transform schedule to
...@@ -25,6 +27,7 @@ import tvm.tir ...@@ -25,6 +27,7 @@ import tvm.tir
from tvm.runtime import ndarray from tvm.runtime import ndarray
from tvm.ir import container from tvm.ir import container
from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass from tvm.tir import ir_pass
from tvm.tir.stmt import LoweredFunc from tvm.tir.stmt import LoweredFunc
...@@ -222,75 +225,59 @@ def _build_for_device(flist, target, target_host): ...@@ -222,75 +225,59 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module mdev : tvm.module
A module that contains device code. A module that contains device code.
""" """
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target
# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)
target = _target.create(target) target = _target.create(target)
target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist: for func in flist:
if not ir_pass.VerifyMemory(func, device_type): if not ir_pass.VerifyMemory(func, device_type):
raise ValueError( raise ValueError(
"Direct host side access to device memory is detected in %s. " "Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name) "Did you forget to bind?" % func.name)
if func.func_type == LoweredFunc.MixedFunc:
if BuildConfig.current().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = list(ir_pass.SplitHostDevice(func))
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
if BuildConfig.current().detect_global_barrier:
opt_mixed += [tvm.tir.transform.ThreadSync("global")]
opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
tvm.tir.transform.ThreadSync("warp"),
tvm.tir.transform.InferFragment(),
tvm.tir.transform.LowerThreadAllreduce(),
tvm.tir.transform.BindDeviceType(),
tvm.tir.transform.SplitHostDevice()]
mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice
target_host = _target.create(target_host)
# device optimizations # device optimizations
mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
opt_device = tvm.ir.transform.Sequential( opt_device = tvm.ir.transform.Sequential(
[BindTarget(target), [tvm.tir.transform.Filter(
lambda f: "calling_conv" in f.attrs and
f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.LowerWarpMemory(), tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()]) tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_dev) mod_dev = opt_device(mod_mixed)
# host optimizations # host optimizations
mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
opt_host = tvm.ir.transform.Sequential( opt_host = tvm.ir.transform.Sequential(
[BindTarget(target_host), [tvm.tir.transform.Filter(
lambda f: "calling_conv" not in f.attrs or
f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
tvm.tir.transform.LowerTVMBuiltin(), tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(), tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()]) tvm.tir.transform.CombineContextCall()])
mod_host = opt_host(mod_host) mod_host = opt_host(mod_mixed)
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert len(mod_dev.functions) == 0
if "gpu" in target.keys and len(mod_dev.functions) == 0:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
return mod_host, rt_mod_dev return mod_host, rt_mod_dev
......
...@@ -23,7 +23,7 @@ from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType ...@@ -23,7 +23,7 @@ from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .function import BaseFunc from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData from .adt import Constructor, TypeData
from .module import IRModule from .module import IRModule
from .attrs import Attrs, DictAttrs, make_node from .attrs import Attrs, DictAttrs, make_node
......
...@@ -15,10 +15,18 @@ ...@@ -15,10 +15,18 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Function defintiions.""" """Function defintiions."""
from enum import IntEnum
from .expr import RelayExpr from .expr import RelayExpr
from . import _ffi_api from . import _ffi_api
class CallingConv(IntEnum):
"""Possible kinds of calling conventions."""
DEFAULT = 0
C_PACKED_FUNC = 1
DEVICE_KERNEL_LAUNCH = 2
class BaseFunc(RelayExpr): class BaseFunc(RelayExpr):
"""Base class of all functions.""" """Base class of all functions."""
@property @property
......
...@@ -60,7 +60,6 @@ class IRModule(Node): ...@@ -60,7 +60,6 @@ class IRModule(Node):
type_definitions = mapped_type_defs type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions) self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
def __setitem__(self, var, val): def __setitem__(self, var, val):
"""Add a mapping to the module. """Add a mapping to the module.
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# under the License. # under the License.
"""TIR specific function pass support.""" """TIR specific function pass support."""
import inspect import inspect
import types
import functools import functools
import tvm._ffi import tvm._ffi
...@@ -142,7 +143,7 @@ def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None): ...@@ -142,7 +143,7 @@ def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
return _wrap_class_function_pass(pass_arg, info) return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass") raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.MakeFunctionPass(pass_arg, info) return _ffi_api.CreatePrimFuncPass(pass_arg, info)
if pass_func: if pass_func:
return create_function_pass(pass_func) return create_function_pass(pass_func)
......
...@@ -17,6 +17,70 @@ ...@@ -17,6 +17,70 @@
"""Wrapping existing transformations.""" """Wrapping existing transformations."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from . import _ffi_api from . import _ffi_api
from . import function_pass as _fpass
def Apply(ftransform):
"""Apply ftransform to each function in the Module.
This function is a thin wrapper around tvm.tir.transform.prim_func_pass
Parameters
----------
ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc
The transformation pass.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return ftransform(func)
return _fpass.prim_func_pass(_transform, opt_level=0)
def Filter(fcond):
"""Filter functions by the calling convention attribute.
Parameters
----------
fcond : tvm.tir.PrimFunc -> bool
The condition of the filtering.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return func if fcond(func) else None
return _fpass.prim_func_pass(_transform, opt_level=0)
def BindDeviceType():
"""Bind the device type of the function to be
the device_type specified in the target attribute.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.BindDeviceType()
def SplitHostDevice():
"""Split the function into a host function and device functions.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.SplitHostDevice()
def SkipAssert(): def SkipAssert():
......
...@@ -185,75 +185,50 @@ transform::Pass BindTarget(Target target) { ...@@ -185,75 +185,50 @@ transform::Pass BindTarget(Target target) {
} }
template<typename FCond>
transform::Pass FilterBy(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return f;
} else {
return tir::PrimFunc(nullptr);
}
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
}
std::pair<IRModule, IRModule> std::pair<IRModule, IRModule>
split_dev_host_funcs(const Array<LoweredFunc>& funcs, split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target, const Target& target,
const Target& target_host, const Target& target_host,
const BuildConfig& config) { const BuildConfig& config) {
std::unordered_set<std::string> all_names;
for (const auto& x : funcs) {
CHECK(all_names.count(x->name) == 0)
<< "Duplicate function name " << x->name;
all_names.insert(x->name);
}
Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice;
for (const auto& x : funcs) { for (const auto& x : funcs) {
CHECK(tir::VerifyMemory(x, target->device_type)) CHECK(tir::VerifyMemory(x, target->device_type))
<< "Direct host side access to device memory is detected in " << "Direct host side access to device memory is detected in "
<< x->func_name() << ". Did you forget to bind?"; << x->func_name() << ". Did you forget to bind?";
if (x->func_type == tir::kMixedFunc) {
auto func = x;
if (config->detect_global_barrier) {
func = tir::ThreadSync(func, "global");
}
func = tir::ThreadSync(func, "shared");
func = tir::ThreadSync(func, "warp");
func = tir::InferFragment(func);
func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = tir::SplitHostDevice(func);
fhost.push_back(fsplits[0]);
for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
fdevice.push_back(*f);
}
} else if (x->func_type == tir::kHostFunc) {
fhost.push_back(x);
} else if (x->func_type == tir::kDeviceFunc) {
fdevice.push_back(x);
} else {
LOG(FATAL) << "unknown function type " << x->func_type;
}
}
auto keys = target->keys();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && fdevice.size() == 0) {
LOG(WARNING) << "Specified target "
<< target->str()
<< " but cannot find device code. Did you forget to bind?";
} }
IRModule mod_mixed = codegen::ToIRModule(funcs);
if (target->device_type == target::llvm()->device_type && Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
target_host == target) { if (config->detect_global_barrier) {
CHECK(fdevice.empty()) << "No device code should be generated when target " mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
<< "and host_target are both llvm target."
<< "\n";
}
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = tir::BindDeviceType(func, target->device_type);
fhost.Set(i, func);
} }
mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
mixed_pass_list.push_back(tir::transform::BindDeviceType());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
auto opt_mixed = transform::Sequential(mixed_pass_list);
mod_mixed = opt_mixed(std::move(mod_mixed));
// host pipeline
auto mhost = codegen::ToIRModule(fhost);
auto host_pass_list = { auto host_pass_list = {
FilterBy([](const tir::PrimFunc& f) {
int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
return value != static_cast<int>(CallingConv::kDeviceKernelLaunch);
}),
BindTarget(target_host), BindTarget(target_host),
tir::transform::LowerTVMBuiltin(), tir::transform::LowerTVMBuiltin(),
tir::transform::LowerIntrin(), tir::transform::LowerIntrin(),
...@@ -261,18 +236,38 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -261,18 +236,38 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs,
tir::transform::CombineContextCall(), tir::transform::CombineContextCall(),
}; };
auto opt_host = transform::Sequential(host_pass_list); auto opt_host = transform::Sequential(host_pass_list);
mhost = opt_host(mhost); auto mhost = opt_host(mod_mixed);
// device pipeline // device pipeline
auto mdevice = codegen::ToIRModule(fdevice);
auto device_pass_list = { auto device_pass_list = {
FilterBy([](const tir::PrimFunc& f) {
int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
return value == static_cast<int>(CallingConv::kDeviceKernelLaunch);
}),
BindTarget(target), BindTarget(target),
tir::transform::LowerWarpMemory(), tir::transform::LowerWarpMemory(),
tir::transform::LowerIntrin(), tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(), tir::transform::LowerDeviceStorageAccessInfo(),
}; };
auto opt_device = transform::Sequential(device_pass_list); auto opt_device = transform::Sequential(device_pass_list);
mdevice = opt_device(mdevice); auto mdevice = opt_device(mod_mixed);
// some final misc checks.
auto keys = target->keys();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && mdevice->functions.size() == 0) {
LOG(WARNING) << "Specified target "
<< target->str()
<< " but cannot find device code. Did you forget to bind?";
}
if (target->device_type == target::llvm()->device_type &&
target_host == target) {
CHECK(mdevice->functions.empty())
<< "No device code should be generated when target "
<< "and host_target are both llvm target."
<< "\n";
}
return {mhost, mdevice}; return {mhost, mdevice};
} }
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
*/ */
#include <tvm/ir/type_functor.h> #include <tvm/ir/type_functor.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/tir/function.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include "doc.h" #include "doc.h"
...@@ -434,6 +435,10 @@ class RelayTextPrinter : ...@@ -434,6 +435,10 @@ class RelayTextPrinter :
Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) { Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
if (auto* n = base_func.as<relay::FunctionNode>()) { if (auto* n = base_func.as<relay::FunctionNode>()) {
return PrintFunc(prefix, GetRef<relay::Function>(n)); return PrintFunc(prefix, GetRef<relay::Function>(n));
} else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
std::ostringstream os;
os << GetRef<tir::PrimFunc>(n);
return Doc::RawText(os.str());
} else { } else {
// def @xyz = meta['ExternalFunc'][id] // def @xyz = meta['ExternalFunc'][id]
Doc doc; Doc doc;
...@@ -455,8 +460,9 @@ class RelayTextPrinter : ...@@ -455,8 +460,9 @@ class RelayTextPrinter :
} }
// functions // functions
for (const auto& kv : mod->functions) { for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second); if (kv.second.as<relay::FunctionNode>()) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
}
if (counter++ != 0) { if (counter++ != 0) {
doc << Doc::NewLine(); doc << Doc::NewLine();
} }
......
...@@ -50,9 +50,10 @@ tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) { ...@@ -50,9 +50,10 @@ tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
Map<tir::Var, PrimExpr> remap_vars; Map<tir::Var, PrimExpr> remap_vars;
for (auto var : from->args) { for (auto var : from->args) {
if (from->handle_data_type.count(var)) { auto it = from->handle_data_type.find(var);
if (it != from->handle_data_type.end()) {
tir::Var new_var(var->name_hint, tir::Var new_var(var->name_hint,
PointerType(PrimType(var->dtype))); PointerType(PrimType((*it).second->dtype)));
args.push_back(new_var); args.push_back(new_var);
remap_vars.Set(var, new_var); remap_vars.Set(var, new_var);
} else { } else {
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/c_runtime_api.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include "codegen_cpu.h" #include "codegen_cpu.h"
......
...@@ -108,8 +108,13 @@ IRModule PrimFuncPassNode::operator()(const IRModule& mod, ...@@ -108,8 +108,13 @@ IRModule PrimFuncPassNode::operator()(const IRModule& mod,
updates.push_back({it.first, updated_func}); updates.push_back({it.first, updated_func});
} }
} }
// automatic removal of None
for (const auto& pair : updates) { for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true); if (pair.second.defined()) {
updated_mod->Add(pair.first, pair.second, true);
} else {
updated_mod->Remove(pair.first);
}
} }
pass_ctx.Trace(updated_mod, pass_info, false); pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod; return updated_mod;
......
...@@ -128,10 +128,7 @@ REGISTER_PASS(VectorizeLoop); ...@@ -128,10 +128,7 @@ REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(SkipVectorize); REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop); REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin); REGISTER_PASS(InjectCopyIntrin);
REGISTER_PASS(ThreadSync);
REGISTER_PASS(MakeAPI); REGISTER_PASS(MakeAPI);
REGISTER_PASS(BindDeviceType);
REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite); REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync); REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo); REGISTER_PASS(LowerStorageAccessInfo);
...@@ -141,7 +138,6 @@ REGISTER_PASS(InjectDoubleBuffer); ...@@ -141,7 +138,6 @@ REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition); REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp); REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope); REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(VerifyMemory); REGISTER_PASS(VerifyMemory);
...@@ -150,7 +146,6 @@ REGISTER_PASS(DecorateDeviceScope); ...@@ -150,7 +146,6 @@ REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer); REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse); REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
REGISTER_PASS(NarrowDataType); REGISTER_PASS(NarrowDataType);
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -218,69 +218,6 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -218,69 +218,6 @@ LoweredFunc MakeAPI(Stmt body,
return f; return f;
} }
class DeviceTypeBinder: public StmtExprMutator {
public:
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_context_type) {
if (const VarNode* var = op->value.as<VarNode>()) {
var_ = var;
PrimExpr value = make_const(op->value.dtype(), device_type_);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmtNode::make(op->value == value, os.str(), body);
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const IfThenElseNode* op) final {
// eager simplify if guard.
Stmt res = StmtExprMutator::VisitStmt_(op);
op = res.as<IfThenElseNode>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
return EvaluateNode::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
}
return res;
}
PrimExpr VisitExpr_(const NENode* op) final {
// eager check NE for device check
PrimExpr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NENode>();
if (tir::ExprDeepEqual()(op->a, op->b)) {
return make_const(op->dtype, false);
}
return res;
}
PrimExpr VisitExpr_(const VarNode* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
return GetRef<PrimExpr>(op);
}
}
public:
const VarNode* var_{nullptr};
int device_type_;
};
LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type)(n->body);
return LoweredFunc(n);
}
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file bind_device_type.cc
* \brief Bind the device type according to the target field.
*/
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/analysis.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace tir {
class DeviceTypeBinder: public StmtExprMutator {
public:
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_context_type) {
if (const VarNode* var = op->value.as<VarNode>()) {
var_ = var;
PrimExpr value = make_const(op->value.dtype(), device_type_);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmtNode::make(op->value == value, os.str(), body);
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const IfThenElseNode* op) final {
// eager simplify if guard.
Stmt res = StmtExprMutator::VisitStmt_(op);
op = res.as<IfThenElseNode>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
return EvaluateNode::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
}
return res;
}
PrimExpr VisitExpr_(const NENode* op) final {
// eager check NE for device check
PrimExpr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NENode>();
if (tir::ExprDeepEqual()(op->a, op->b)) {
return make_const(op->dtype, false);
}
return res;
}
PrimExpr VisitExpr_(const VarNode* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
return GetRef<PrimExpr>(op);
}
}
public:
const VarNode* var_{nullptr};
int device_type_;
};
namespace transform {
/*!
 * \brief Create a pass that binds the device type of each PrimFunc
 *  according to its tvm::attr::kTarget attribute.
 * \return The created pass.
 */
Pass BindDeviceType() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    // Validate the precondition before triggering copy-on-write, so a
    // function that fails the check is not copied needlessly.
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    CHECK(target.defined())
        << "BindDeviceType: Require the target attribute";
    auto* n = f.CopyOnWrite();
    n->body = DeviceTypeBinder(target->device_type)(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {});
}

TVM_REGISTER_GLOBAL("tir.transform.BindDeviceType")
.set_body_typed(BindDeviceType);
} // namespace transform
} // namespace tir
} // namespace tvm
...@@ -340,14 +340,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { ...@@ -340,14 +340,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::unordered_map<const VarNode *, Stmt> alloc_remap_; std::unordered_map<const VarNode *, Stmt> alloc_remap_;
}; };
LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size)(n->body);
return LoweredFunc(n);
}
namespace transform { namespace transform {
Pass LowerThreadAllreduce() { Pass LowerThreadAllreduce() {
...@@ -356,10 +348,6 @@ Pass LowerThreadAllreduce() { ...@@ -356,10 +348,6 @@ Pass LowerThreadAllreduce() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) CHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute"; << "LowerThreadAllreduce: Require the target attribute";
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "LowerThreadAllreeduce: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body);
return f; return f;
}; };
......
...@@ -21,18 +21,22 @@ ...@@ -21,18 +21,22 @@
* \file split_host_device.cc * \file split_host_device.cc
* \brief Split device function from host. * \brief Split device function from host.
*/ */
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h> #include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/module.h> #include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <unordered_map> #include <unordered_map>
namespace tvm { namespace tvm {
namespace tir { namespace tir {
// use/def analysis, also delete unreferenced lets // use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public StmtExprMutator { class VarUseDefAnalysis : public StmtExprMutator {
public: public:
Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
...@@ -156,8 +160,27 @@ class IRUseDefAnalysis : public StmtExprMutator { ...@@ -156,8 +160,27 @@ class IRUseDefAnalysis : public StmtExprMutator {
std::unordered_map<const VarNode*, int> def_count_; std::unordered_map<const VarNode*, int> def_count_;
}; };
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
VarUseDefAnalysis m;
for (Var arg : args) {
m.use_count_[arg.get()] = 0;
}
m(stmt);
return m.undefined_;
}
class HostDeviceSplitter : public StmtMutator { class HostDeviceSplitter : public StmtMutator {
public: public:
explicit HostDeviceSplitter(IRModuleNode* device_mod,
Target device_target,
std::string name_prefix)
: device_mod_(device_mod),
device_target_(device_target),
name_prefix_(name_prefix) {
}
Stmt VisitStmt_(const AllocateNode* op) final { Stmt VisitStmt_(const AllocateNode* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0); handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
...@@ -172,86 +195,128 @@ class HostDeviceSplitter : public StmtMutator { ...@@ -172,86 +195,128 @@ class HostDeviceSplitter : public StmtMutator {
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
} }
Array<LoweredFunc> Split(LoweredFunc f) {
CHECK_EQ(f->func_type, kMixedFunc);
for (auto kv : f->handle_data_type) {
handle_data_type_[kv.first.get()] = kv.second;
}
name_ = f->name;
ObjectPtr<LoweredFuncNode> n =
make_object<LoweredFuncNode>(*f.operator->());
n->body = operator()(f->body);
n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
ret.push_back(x);
}
return ret;
}
private: private:
Stmt SplitDeviceFunc(Stmt body) { Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os; std::ostringstream os;
os << name_ << "_kernel" << device_funcs_.size(); os << name_prefix_ << "_kernel" << device_func_counter_++;
ObjectPtr<LoweredFuncNode> n = make_object<LoweredFuncNode>(); std::string kernel_symbol = os.str();
// isolate the device function. // isolate the device function.
IRUseDefAnalysis m; VarUseDefAnalysis m;
m.visit_thread_extent_ = false; m.visit_thread_extent_ = false;
n->body = m(std::move(body)); body = m(std::move(body));
n->name = os.str();
n->func_type = kDeviceFunc; Array<Var> params;
n->thread_axis = m.thread_axis_; Array<PrimExpr> arguments;
Map<tir::Var, PrimExpr> remap_vars;
// Strictly order the arguments: Var pointers, positional arguments. // Strictly order the arguments: Var pointers, positional arguments.
for (Var v : m.undefined_) { for (Var var : m.undefined_) {
if (v.dtype().is_handle()) { if (var.dtype().is_handle()) {
n->args.push_back(v); // Create a new version of v.
// mark handle data type. auto it = handle_data_type_.find(var.get());
auto it = handle_data_type_.find(v.get());
if (it != handle_data_type_.end()) { if (it != handle_data_type_.end()) {
n->handle_data_type.Set(v, it->second); tir::Var new_var(var->name_hint,
PointerType(PrimType((*it).second->dtype)));
params.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
params.push_back(var);
} }
arguments.push_back(var);
} }
} }
for (Var v : m.undefined_) { // positional arguments
if (!v.dtype().is_handle()) { for (Var var : m.undefined_) {
n->args.push_back(v); if (!var.dtype().is_handle()) {
params.push_back(var);
arguments.push_back(var);
} }
} }
LoweredFunc f_device(n); PrimFunc device_func(params, Substitute(body, remap_vars));
device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_);
device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
Integer(CallingConv::kDeviceKernelLaunch));
device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
runtime::String(kernel_symbol));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
device_mod_->Add(GlobalVar(kernel_symbol), device_func);
// generate calls to the device function
Array<PrimExpr> call_args; Array<PrimExpr> call_args;
call_args.push_back(StringImmNode::make(f_device->name)); call_args.push_back(StringImmNode::make(kernel_symbol));
for (Var arg : n->args) { for (PrimExpr arg : arguments) {
call_args.push_back(arg); call_args.push_back(arg);
} }
for (PrimExpr ext : m.thread_extent_) { for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext); call_args.push_back(ext);
} }
device_funcs_.emplace_back(f_device);
return EvaluateNode::make(CallNode::make( return EvaluateNode::make(CallNode::make(
DataType::Int(32), intrinsic::tvm_call_packed, DataType::Int(32), intrinsic::tvm_call_packed,
call_args, CallNode::Intrinsic)); call_args, CallNode::Intrinsic));
} }
// function name // target ir module
std::string name_; IRModuleNode* device_mod_;
// the device functions // Device target
Target device_target_;
// function name hint
std::string name_prefix_;
// Number of device functions.
int device_func_counter_{0};
std::vector<LoweredFunc> device_funcs_; std::vector<LoweredFunc> device_funcs_;
std::unordered_map<const VarNode*, PrimExpr> handle_data_type_; std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
}; };
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) { PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
IRUseDefAnalysis m; auto target = func->GetAttr<Target>(tvm::attr::kTarget);
for (Var arg : args) { CHECK(target.defined())
m.use_count_[arg.get()] = 0; << "SplitHostDevice: Require the target attribute";
} auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
m(stmt); CHECK(global_symbol.defined())
return m.undefined_; << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
HostDeviceSplitter splitter(
device_mod, target, static_cast<std::string>(global_symbol));
auto* n = func.CopyOnWrite();
n->body = splitter(std::move(n->body));
// set the host target to None.
func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
return std::move(func);
} }
Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
return HostDeviceSplitter().Split(func);
namespace transform {
Pass SplitHostDevice() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
for (const auto& kv : mptr->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
auto updated_func = SplitHostDevice(std::move(func), mptr);
updates.push_back({kv.first, updated_func});
}
}
for (const auto& pair : updates) {
mptr->Add(pair.first, pair.second, true);
}
return m;
};
return tvm::transform::CreateModulePass(
pass_func, 0, "tir.SplitHostDevice", {});
} }
TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice")
.set_body_typed(SplitHostDevice);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -218,26 +218,19 @@ Stmt InferFragment(Stmt stmt) { ...@@ -218,26 +218,19 @@ Stmt InferFragment(Stmt stmt) {
return stmt; return stmt;
} }
LoweredFunc InferFragment(LoweredFunc f) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = InferFragment(f->body);
return LoweredFunc(n);
}
namespace transform { namespace transform {
Pass InferFragement() { Pass InferFragment() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
n->body = InferFragment(std::move(n->body)); n->body = InferFragment(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.InferFragement", {}); return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
} }
TVM_REGISTER_GLOBAL("tir.transform.InferFragement") TVM_REGISTER_GLOBAL("tir.transform.InferFragment")
.set_body_typed(InferFragement); .set_body_typed(InferFragment);
} // namespace transform } // namespace transform
} // namespace tir } // namespace tir
......
...@@ -374,13 +374,6 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { ...@@ -374,13 +374,6 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt)); return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
} }
LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ThreadSync(f->body, storage_scope);
return LoweredFunc(n);
}
namespace transform { namespace transform {
Pass ThreadSync(std::string storage_scope) { Pass ThreadSync(std::string storage_scope) {
......
...@@ -28,7 +28,7 @@ def test_loop_dependent_allocate(): ...@@ -28,7 +28,7 @@ def test_loop_dependent_allocate():
s[AA].compute_at(s[C], s[C].op.axis[0]) s[AA].compute_at(s[C], s[C].op.axis[0])
# this line should fail due to IRUseDefAnalysis sees an allocate statement # this line should fail due to IRUseDefAnalysis sees an allocate statement
# referencing undefined variable # referencing undefined variable
tvm.lower(s, [A,C]) tvm.lower(s, [A, C])
if __name__ == "__main__": if __name__ == "__main__":
test_loop_dependent_allocate() test_loop_dependent_allocate()
...@@ -41,7 +41,9 @@ def test_double_buffer(): ...@@ -41,7 +41,9 @@ def test_double_buffer():
assert isinstance(stmt.body.body, tvm.tir.Allocate) assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.ir_pass.ThreadSync(f, "shared") mod = tvm.testing.LoweredFuncsToIRModule([f])
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0] count = [0]
def count_sync(op): def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
......
...@@ -93,7 +93,10 @@ def test_flatten_double_buffer(): ...@@ -93,7 +93,10 @@ def test_flatten_double_buffer():
assert isinstance(stmt.body.body, tvm.tir.Allocate) assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.ir_pass.ThreadSync(f, "shared") f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
mod = tvm.testing.LoweredFuncsToIRModule([f])
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0] count = [0]
def count_sync(op): def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync": if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
......
...@@ -33,16 +33,15 @@ def test_lower_warp_mem(): ...@@ -33,16 +33,15 @@ def test_lower_warp_mem():
xo, xi = s[AA].split(s[AA].op.axis[0], 32) xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx) s[AA].bind(xi, tx)
f = tvm.lower(s, [A, B])
fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
# temp adapter to convert loweredFunc to IRModule
# to test passes in the new style.
fname = fdevice.name
mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
cuda_target = tvm.target.create("cuda") cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32 assert cuda_target.thread_warp_size == 32
mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target)) f = tvm.lower(s, [A, B], name="f")
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"] fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
assert(fdevice.body.body.value.value == "local") assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2) assert(fdevice.body.body.body.extents[0].value == 2)
......
...@@ -38,13 +38,13 @@ def test_thread_storage_sync(): ...@@ -38,13 +38,13 @@ def test_thread_storage_sync():
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True) f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
flist = tvm.tir.ir_pass.SplitHostDevice(f) cuda_target = tvm.target.create("cuda")
f = flist[1]
fname = f.name
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
cuda_target = tvm.target.create("cuda") cuda_target = tvm.target.create("cuda")
mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
f = tvm.tir.transform.ThreadSync("shared")(mod)["main"] f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
body_list = tvm.tir.stmt_list(f.body.body.body.body) body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync") assert(body_list[1].value.name == "tvm_storage_sync")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment