Commit 44bffdb3 by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Migrate low-level pass functions to Pass Manager, (#5213)

- Migrate LowerTVMBultin
- Migrate inferFragment, LowerThreadAllreduce
- Migrate ThreadSync
- Refactor target::Build to directly take IRModule.
- Remove un-used legacy functions.
parent 88d2f34b
...@@ -57,20 +57,6 @@ TVM_DLL Array<tir::LoweredFunc> lower( ...@@ -57,20 +57,6 @@ TVM_DLL Array<tir::LoweredFunc> lower(
const std::string& name, const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds, const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config); const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<tir::LoweredFunc> > split_dev_host_funcs(
const Array<tir::LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*! /*!
* \brief Build a device and host module for a specific target from an array of lowered functions. * \brief Build a device and host module for a specific target from an array of lowered functions.
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define TVM_TARGET_CODEGEN_H_ #define TVM_TARGET_CODEGEN_H_
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h> #include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h> #include <tvm/target/target.h>
...@@ -41,15 +42,24 @@ using runtime::TVMArgs; ...@@ -41,15 +42,24 @@ using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
/*! /*!
* \brief Temporary backward compatible function to convert a list
* of LoweredFunc to a IRModule of PrimfFuncs
* \param funcs The input lowered function.
* \return The IRModule.
*
* \note This function is only used for code refactor and will be
* removed once the refactor completes.
*/
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);
/*!
* \brief Build a module from array of lowered function. * \brief Build a module from array of lowered function.
* \param funcs The functions to be built. * \param mod The Module to be built
* \param target The target to be built. * \param target The target to be built.
* \return The builded module. * \return The result runtime::Module.
*
* \note Calls global API function "_codegen_build_" + target
*/ */
runtime::Module Build(const Array<tir::LoweredFunc>& funcs, runtime::Module Build(IRModule mod, const Target& target);
const std::string& target);
/*! /*!
* \brief Pack imported device library to a C file. * \brief Pack imported device library to a C file.
* Compile the C file and link with the host library * Compile the C file and link with the host library
......
...@@ -477,12 +477,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map); ...@@ -477,12 +477,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
*/ */
LoweredFunc LowerTVMBuiltin(LoweredFunc f); LoweredFunc LowerTVMBuiltin(LoweredFunc f);
/*!
* \brief Combine context function calls.
* \param f The host function to be lowered.
* \return Transformed function.
*/
LoweredFunc CombineContextCall(LoweredFunc f);
/*! /*!
* \brief Rewrite the pointer content type of arguments, * \brief Rewrite the pointer content type of arguments,
...@@ -496,7 +490,6 @@ LoweredFunc CombineContextCall(LoweredFunc f); ...@@ -496,7 +490,6 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/ */
LoweredFunc PointerValueTypeRewrite(LoweredFunc f); LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*! /*!
* \brief Rewrite the pointer content type of arguments, * \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use * as well as Alloc internal to the function to use
...@@ -510,23 +503,6 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); ...@@ -510,23 +503,6 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
PrimFunc PointerValueTypeRewrite(PrimFunc f); PrimFunc PointerValueTypeRewrite(PrimFunc f);
/*! /*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
*
* \param func The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
/*!
* \brief Lower custom datatypes. * \brief Lower custom datatypes.
* *
* See tvm::datatypes::Registry for more information on adding custom datatypes. * See tvm::datatypes::Registry for more information on adding custom datatypes.
...@@ -546,13 +522,6 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); ...@@ -546,13 +522,6 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
LoweredFunc InferFragment(LoweredFunc f); LoweredFunc InferFragment(LoweredFunc f);
/*! /*!
* \brief skip assert stmt generation
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc SkipAssert(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type. * \brief Verify if memory accesses are legal for a specific target device type.
* *
* In the case that tgt is cuda, if not all workload is bound with * In the case that tgt is cuda, if not all workload is bound with
......
...@@ -59,11 +59,40 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< ...@@ -59,11 +59,40 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required); const tvm::Array<tvm::PrimExpr>& required);
/*! /*!
* \brief Combine context calls in the host function. * \brief skip assert stmt.
* *
* \return The pass. * \return The pass.
*/ */
TVM_DLL Pass CombineContextCall(); TVM_DLL Pass SkipAssert();
/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param storage_scope The storage scope considered.
* \return The pass.
*/
TVM_DLL Pass ThreadSync(std::string storage_scope);
/*!
* \brief Lower cross thread alleduce.
*
* \return The pass.
*/
TVM_DLL Pass LowerThreadAllreduce();
/*!
* \brief Infer the TensorCore fragment infomation using tensor intrinsics
*
* \return The pass.
*/
TVM_DLL Pass InferFragment();
/*!
* \brief Lower builtin intrinsics.
* \return The pass.
*/
TVM_DLL Pass LowerTVMBuiltin();
/*! /*!
* \brief Lower the target specific function intrinsics in each of the function. * \brief Lower the target specific function intrinsics in each of the function.
...@@ -73,6 +102,12 @@ TVM_DLL Pass CombineContextCall(); ...@@ -73,6 +102,12 @@ TVM_DLL Pass CombineContextCall();
TVM_DLL Pass LowerIntrin(); TVM_DLL Pass LowerIntrin();
/*! /*!
* \brief Lower warp memory access to low-level device related function calls.
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();
/*!
* \brief Lower attached storage access information on device. * \brief Lower attached storage access information on device.
* *
* \note Run this pass after all storage access analysis finish. * \note Run this pass after all storage access analysis finish.
...@@ -82,10 +117,11 @@ TVM_DLL Pass LowerIntrin(); ...@@ -82,10 +117,11 @@ TVM_DLL Pass LowerIntrin();
TVM_DLL Pass LowerDeviceStorageAccessInfo(); TVM_DLL Pass LowerDeviceStorageAccessInfo();
/*! /*!
* \brief Lower warp memory access to low-level device related function calls. * \brief Combine context calls in the host function.
*
* \return The pass. * \return The pass.
*/ */
TVM_DLL Pass LowerWarpMemory(); TVM_DLL Pass CombineContextCall();
/*! /*!
......
...@@ -222,6 +222,15 @@ def _build_for_device(flist, target, target_host): ...@@ -222,6 +222,15 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module mdev : tvm.module
A module that contains device code. A module that contains device code.
""" """
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target
# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)
target = _target.create(target) target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type device_type = ndarray.context(target.target_name, 0).device_type
fhost = [] fhost = []
...@@ -250,30 +259,39 @@ def _build_for_device(flist, target, target_host): ...@@ -250,30 +259,39 @@ def _build_for_device(flist, target, target_host):
else: else:
raise ValueError("unknown function type %d" % func.func_type) raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice: if "gpu" in target.keys and not fdevice:
warnings.warn( warnings.warn(
"Specified target %s, but cannot find device code, did you do " "Specified target %s, but cannot find device code, did you do "
"bind?" % target) "bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost] fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if device_type == ndarray.cpu(0).device_type and target_host == target: if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice assert not fdevice
target_host = _target.create(target_host) target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
return fhost, mdev # device optimizations
mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
opt_device = tvm.ir.transform.Sequential(
[BindTarget(target),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_dev)
# host optimizations
mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
opt_host = tvm.ir.transform.Sequential(
[BindTarget(target_host),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()])
mod_host = opt_host(mod_host)
rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
return mod_host, rt_mod_dev
def build(inputs, def build(inputs,
...@@ -402,19 +420,19 @@ def build(inputs, ...@@ -402,19 +420,19 @@ def build(inputs,
if not target_host: if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
fhost_all = [] mod_host_all = tvm.IRModule({})
device_modules = [] device_modules = []
for tar, flist in target_flist.items(): for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host) mod_host, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module. mod_host_all.update(mod_host)
fhost_all += fhost
device_modules.append(mdev) device_modules.append(mdev)
# Generate a unified host module. # Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host)) rt_mod_host = codegen.build_module(mod_host_all, target_host)
# Import all modules. # Import all modules.
for mdev in device_modules: for mdev in device_modules:
if mdev: if mdev:
mhost.import_module(mdev) rt_mod_host.import_module(mdev)
return mhost return rt_mod_host
...@@ -17,15 +17,16 @@ ...@@ -17,15 +17,16 @@
# under the License. # under the License.
"""Code generation related functions.""" """Code generation related functions."""
from . import _ffi_api from . import _ffi_api
from . import target as _tgt
def build_module(lowered_func, target): def build_module(mod, target):
"""Build lowered_func into Module. """Build IRModule into Module.
Parameters Parameters
---------- ----------
lowered_func : LoweredFunc mod : tvm.IRModule
The lowered function The ir module.
target : str target : str
The target module type. The target module type.
...@@ -35,7 +36,8 @@ def build_module(lowered_func, target): ...@@ -35,7 +36,8 @@ def build_module(lowered_func, target):
module : runtime.Module module : runtime.Module
The corressponding module. The corressponding module.
""" """
return _ffi_api.Build(lowered_func, target) target = _tgt.create(target) if isinstance(target, str) else target
return _ffi_api.Build(mod, target)
def llvm_lookup_intrinsic_id(name): def llvm_lookup_intrinsic_id(name):
......
...@@ -16,19 +16,78 @@ ...@@ -16,19 +16,78 @@
# under the License. # under the License.
"""Wrapping existing transformations.""" """Wrapping existing transformations."""
# pylint: disable=invalid-name # pylint: disable=invalid-name
from . import _ffi_api from . import _ffi_api
def CombineContextCall(): def SkipAssert():
"""Combine context calls in the host function. """Skip assert stmt.
Returns Returns
------- -------
fpass : tvm.ir.transform.Pass fpass : tvm.ir.transform.Pass
The result pass The result pass
""" """
return _ffi_api.CombineContextCall() return _ffi_api.SkipAssert()
def ThreadSync(storage_scope):
""" Insert sync between parallel read/write of shared buffers.
Parameters
----------
storage_scope: str
The target storage scope.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.ThreadSync(storage_scope)
def LowerThreadAllreduce():
"""Lower cross thread alleduce.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerThreadAllreduce()
def InferFragment():
""" Infer the TensorCore fragment infomation using tensor intrinsics.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.InferFragment()
def LowerWarpMemory():
"""Lower warp memory access to low-level device related function calls.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerWarpMemory()
def LowerTVMBuiltin():
"""Lower tvm builtin intrinsics.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerTVMBuiltin()
def LowerIntrin(): def LowerIntrin():
...@@ -57,15 +116,15 @@ def LowerDeviceStorageAccessInfo(): ...@@ -57,15 +116,15 @@ def LowerDeviceStorageAccessInfo():
return _ffi_api.LowerDeviceStorageAccessInfo() return _ffi_api.LowerDeviceStorageAccessInfo()
def LowerWarpMemory(): def CombineContextCall():
"""Lower warp memory access to low-level device related function calls. """Combine context calls in the host function.
Returns Returns
------- -------
fpass : tvm.ir.transform.Pass fpass : tvm.ir.transform.Pass
The result pass The result pass
""" """
return _ffi_api.LowerWarpMemory() return _ffi_api.CombineContextCall()
def NarrowDataType(): def NarrowDataType():
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <tvm/driver/driver_api.h> #include <tvm/driver/driver_api.h>
#include <tvm/te/operation.h> #include <tvm/te/operation.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
...@@ -174,10 +176,20 @@ Array<LoweredFunc> lower(te::Schedule sch, ...@@ -174,10 +176,20 @@ Array<LoweredFunc> lower(te::Schedule sch,
return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) }); return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
} }
Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target, transform::Pass BindTarget(Target target) {
const Target& target_host, auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
const BuildConfig& config) { return WithAttr(std::move(f), tvm::attr::kTarget, target);
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {});
}
std::pair<IRModule, IRModule>
split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
std::unordered_set<std::string> all_names; std::unordered_set<std::string> all_names;
for (const auto& x : funcs) { for (const auto& x : funcs) {
CHECK(all_names.count(x->name) == 0) CHECK(all_names.count(x->name) == 0)
...@@ -217,13 +229,6 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -217,13 +229,6 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
} }
} }
for (size_t i = 0; i < fdevice.size(); i++) {
auto warp_size = target->thread_warp_size;
auto func = fdevice[i];
func = tir::LowerWarpMemory(fdevice[i], warp_size);
fdevice.Set(i, func);
}
auto keys = target->keys(); auto keys = target->keys();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && fdevice.size() == 0) { if (target_is_gpu && fdevice.size() == 0) {
...@@ -232,53 +237,46 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs, ...@@ -232,53 +237,46 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
<< " but cannot find device code. Did you forget to bind?"; << " but cannot find device code. Did you forget to bind?";
} }
for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = tir::LowerIntrin(func, target->target_name);
fdevice.Set(i, func);
}
if (target->device_type == target::llvm()->device_type && if (target->device_type == target::llvm()->device_type &&
target_host == target) { target_host == target) {
CHECK(fdevice.empty()) << "No device code should be generated when target " CHECK(fdevice.empty()) << "No device code should be generated when target "
<< "and host_target are both llvm target." << "and host_target are both llvm target."
<< "\n"; << "\n";
} }
for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = tir::LowerDeviceStorageAccessInfo(func);
fdevice.Set(i, func);
}
for (size_t i = 0; i < fhost.size(); ++i) { for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i]; auto func = fhost[i];
func = tir::BindDeviceType(func, target->device_type); func = tir::BindDeviceType(func, target->device_type);
func = tir::LowerDeviceStorageAccessInfo(func);
func = tir::LowerTVMBuiltin(func);
fhost.Set(i, func); fhost.Set(i, func);
} }
for (size_t i = 0; i < fhost.size(); ++i) { // host pipeline
auto func = fhost[i]; auto mhost = codegen::ToIRModule(fhost);
func = tir::LowerIntrin(func, target_host->target_name); auto host_pass_list = {
func = tir::LowerDeviceStorageAccessInfo(func); BindTarget(target_host),
func = tir::CombineContextCall(func); tir::transform::LowerTVMBuiltin(),
fhost.Set(i, func); tir::transform::LowerIntrin(),
} tir::transform::LowerDeviceStorageAccessInfo(),
return {fhost, fdevice}; tir::transform::CombineContextCall(),
};
auto opt_host = transform::Sequential(host_pass_list);
mhost = opt_host(mhost);
// device pipeline
auto mdevice = codegen::ToIRModule(fdevice);
auto device_pass_list = {
BindTarget(target),
tir::transform::LowerWarpMemory(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
};
auto opt_device = transform::Sequential(device_pass_list);
mdevice = opt_device(mdevice);
return {mhost, mdevice};
} }
// Create a module for a specific device (target). The lowered functions
// associated with the host is returned as well.
runtime::Module DeviceBuild(const Array<LoweredFunc>& fdevice,
const Target& target) {
if (!fdevice.empty()) {
return codegen::Build(fdevice, target->str());
} else {
return runtime::Module(nullptr);
}
}
// Build for heterogeneous execution. // Build for heterogeneous execution.
runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs, runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
...@@ -301,20 +299,21 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs, ...@@ -301,20 +299,21 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
target_host_val = DefaultTargetHost(target_host_val); target_host_val = DefaultTargetHost(target_host_val);
} }
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());
for (const auto& it : inputs) { for (const auto& it : inputs) {
auto host_dev_funcs = auto pair =
split_dev_host_funcs(it.second, it.first, target_host_val, config); split_dev_host_funcs(it.second, it.first, target_host_val, config);
auto& fhost = host_dev_funcs[0]; auto& mhost = pair.first;
auto& fdevice = host_dev_funcs[1]; auto& mdevice = pair.second;
// Get the module for a certain target.
runtime::Module mdev = DeviceBuild(fdevice, it.first); mhost_all->Update(mhost);
for (const auto& it : fhost) { if (mdevice->functions.size() != 0) {
fhost_all.push_back(it); device_modules.push_back(codegen::Build(mdevice, it.first));
} }
device_modules.push_back(mdev);
} }
runtime::Module mhost = codegen::Build(fhost_all, target_host_val->str()); runtime::Module mhost = codegen::Build(mhost_all, target_host_val);
// Import all modules // Import all modules
for (const auto& it : device_modules) { for (const auto& it : device_modules) {
if (it.operator->()) { if (it.operator->()) {
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/function.h> #include <tvm/tir/function.h>
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
...@@ -42,18 +43,6 @@ ...@@ -42,18 +43,6 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
// The new build function.
// adapt the old function to the new one
runtime::Module BuildForIRModule(const IRModule& module,
const Target& target) {
std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "target.build." << target << " is not enabled";
return (*bf)(module, target->str());
}
// convert legacy LoweredFunc to PrimFunc. // convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) { tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations. // remap args to attach type annotations.
...@@ -97,24 +86,16 @@ IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) { ...@@ -97,24 +86,16 @@ IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
return IRModule(functions); return IRModule(functions);
} }
runtime::Module Build(const Array<tir::LoweredFunc>& funcs, runtime::Module Build(IRModule mod, const Target& target) {
const std::string& target) {
std::string mode = target;
size_t pos = mode.find(' ');
if (pos != std::string::npos) {
mode = mode.substr(0, pos);
}
Array<tir::LoweredFunc> transformed_funcs;
if (BuildConfig::Current()->disable_assert) { if (BuildConfig::Current()->disable_assert) {
for (const auto& x : funcs) { mod = tir::transform::SkipAssert()(mod);
auto func = tir::SkipAssert(x);
transformed_funcs.push_back(func);
}
} }
std::string build_f_name = "target.build." + target->target_name;
return BuildForIRModule( // the build function.
transformed_funcs.size() != 0 ? ToIRModule(transformed_funcs) : ToIRModule(funcs), const PackedFunc* bf = runtime::Registry::Get(build_f_name);
Target::Create(target)); CHECK(bf != nullptr)
<< "target.build." << target << " is not enabled";
return (*bf)(mod, target->str());
} }
/*! \brief Helper class to serialize module */ /*! \brief Helper class to serialize module */
...@@ -300,13 +281,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, ...@@ -300,13 +281,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
} }
TVM_REGISTER_GLOBAL("target.Build") TVM_REGISTER_GLOBAL("target.Build")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(Build);
if (args[0].IsObjectRef<tir::LoweredFunc>()) {
*ret = Build({args[0]}, args[1]);
} else {
*ret = Build(args[0], args[1]);
}
});
TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule") TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
.set_body_typed(ToIRModule); .set_body_typed(ToIRModule);
......
...@@ -135,7 +135,6 @@ REGISTER_PASS(SplitHostDevice); ...@@ -135,7 +135,6 @@ REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite); REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync); REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo); REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(LowerDeviceStorageAccessInfo)
REGISTER_PASS(InjectVirtualThread); REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch); REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer); REGISTER_PASS(InjectDoubleBuffer);
...@@ -143,12 +142,8 @@ REGISTER_PASS(LoopPartition); ...@@ -143,12 +142,8 @@ REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp); REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope); REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(LowerWarpMemory);
REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerIntrin);
REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(LowerTVMBuiltin);
REGISTER_PASS(CombineContextCall);
REGISTER_PASS(VerifyMemory); REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(DecorateDeviceScope);
......
...@@ -109,18 +109,13 @@ class ContextCallCombiner final : public StmtExprMutator { ...@@ -109,18 +109,13 @@ class ContextCallCombiner final : public StmtExprMutator {
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> ctx_map_; std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> ctx_map_;
}; };
LoweredFunc CombineContextCall(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ContextCallCombiner().Combine(n->body);
return LoweredFunc(n);
}
namespace transform { namespace transform {
Pass CombineContextCall() { Pass CombineContextCall() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
n->body = ContextCallCombiner().Combine(n->body); n->body = ContextCallCombiner().Combine(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {}); return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
......
...@@ -142,11 +142,6 @@ Stmt LowerStorageAccessInfo(Stmt stmt) { ...@@ -142,11 +142,6 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower()(std::move(stmt)); return StorageAccessInfoLower()(std::move(stmt));
} }
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
return LoweredFunc(n);
}
namespace transform { namespace transform {
......
...@@ -283,16 +283,6 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) { ...@@ -283,16 +283,6 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) {
return IntrinInjecter(&analyzer, target_name)(std::move(stmt)); return IntrinInjecter(&analyzer, target_name)(std::move(stmt));
} }
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
std::istringstream is(target);
std::string target_name;
is >> target_name;
n->body = LowerIntrinStmt(n->body, target_name);
return LoweredFunc(n);
}
namespace transform { namespace transform {
Pass LowerIntrin() { Pass LowerIntrin() {
......
...@@ -23,9 +23,14 @@ ...@@ -23,9 +23,14 @@
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h"
#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h" #include "../../runtime/thread_storage_scope.h"
...@@ -342,5 +347,28 @@ LowerThreadAllreduce(LoweredFunc f, int warp_size) { ...@@ -342,5 +347,28 @@ LowerThreadAllreduce(LoweredFunc f, int warp_size) {
n->body = ThreadAllreduceBuilder(warp_size)(n->body); n->body = ThreadAllreduceBuilder(warp_size)(n->body);
return LoweredFunc(n); return LoweredFunc(n);
} }
namespace transform {
Pass LowerThreadAllreduce() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute";
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "LowerThreadAllreeduce: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce")
.set_body_typed(LowerThreadAllreduce);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -18,14 +18,18 @@ ...@@ -18,14 +18,18 @@
*/ */
/*! /*!
* Lower TVM related buildin intrinsics such as packed call. * Lower TVM related builtin intrinsics such as packed call.
* \file lower_tvm_buildin.cc * \file tir/transforms/lower_tvm_buildin.cc
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h"
#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
namespace tvm { namespace tvm {
...@@ -368,11 +372,20 @@ class BuiltinLower : public StmtExprMutator { ...@@ -368,11 +372,20 @@ class BuiltinLower : public StmtExprMutator {
uint64_t max_arg_stack_{0}; uint64_t max_arg_stack_{0};
}; };
LoweredFunc LowerTVMBuiltin(LoweredFunc f) { namespace transform {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = BuiltinLower().Build(n->body); Pass LowerTVMBuiltin() {
return LoweredFunc(n); auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = BuiltinLower().Build(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
} }
TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin")
.set_body_typed(LowerTVMBuiltin);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -385,14 +385,6 @@ class WarpMemoryRewriter : private StmtMutator { ...@@ -385,14 +385,6 @@ class WarpMemoryRewriter : private StmtMutator {
std::unordered_map<const VarNode*, Range> var_dom_; std::unordered_map<const VarNode*, Range> var_dom_;
}; };
LoweredFunc
LowerWarpMemory(LoweredFunc f, int warp_size) {
CHECK_EQ(f->func_type, kDeviceFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body);
return LoweredFunc(n);
}
namespace transform { namespace transform {
Pass LowerWarpMemory() { Pass LowerWarpMemory() {
...@@ -401,7 +393,7 @@ Pass LowerWarpMemory() { ...@@ -401,7 +393,7 @@ Pass LowerWarpMemory() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute"; << "LowerWarpMemory: Require the target attribute";
n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(n->body); n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
......
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/registry.h>
namespace tvm { namespace tvm {
namespace tir { namespace tir {
...@@ -37,11 +39,21 @@ Stmt SkipAssert(Stmt stmt) { ...@@ -37,11 +39,21 @@ Stmt SkipAssert(Stmt stmt) {
return AssertSkipper()(std::move(stmt)); return AssertSkipper()(std::move(stmt));
} }
LoweredFunc SkipAssert(LoweredFunc f) { namespace transform {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = SkipAssert(f->body); Pass SkipAssert() {
return LoweredFunc(n); auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = AssertSkipper()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {});
} }
TVM_REGISTER_GLOBAL("tir.transform.SkipAssert")
.set_body_typed(SkipAssert);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -23,11 +23,15 @@ ...@@ -23,11 +23,15 @@
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/registry.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h"
#include "storage_access.h" #include "../pass/storage_access.h"
#include "../pass/ir_util.h"
#include "../../runtime/thread_storage_scope.h" #include "../../runtime/thread_storage_scope.h"
namespace tvm { namespace tvm {
...@@ -221,5 +225,20 @@ LoweredFunc InferFragment(LoweredFunc f) { ...@@ -221,5 +225,20 @@ LoweredFunc InferFragment(LoweredFunc f) {
return LoweredFunc(n); return LoweredFunc(n);
} }
namespace transform {
Pass InferFragement() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = InferFragment(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InferFragement", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InferFragement")
.set_body_typed(InferFragement);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -18,16 +18,21 @@ ...@@ -18,16 +18,21 @@
*/ */
/*! /*!
* \file storage_sync.cc * \file thread_storage_sync.cc
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h"
#include "storage_access.h" #include "../pass/ir_util.h"
#include "../pass/storage_access.h"
#include "../../runtime/thread_storage_scope.h" #include "../../runtime/thread_storage_scope.h"
namespace tvm { namespace tvm {
...@@ -376,5 +381,20 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) { ...@@ -376,5 +381,20 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
return LoweredFunc(n); return LoweredFunc(n);
} }
namespace transform {
Pass ThreadSync(std::string storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = ThreadSync(std::move(n->body), storage_scope);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
}
TVM_REGISTER_GLOBAL("tir.transform.ThreadSync")
.set_body_typed(ThreadSync);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -50,13 +50,12 @@ def test_dot(): ...@@ -50,13 +50,12 @@ def test_dot():
k = te.reduce_axis((0, n), 'k') k = te.reduce_axis((0, n), 'k')
C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name='C') C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name='C')
s = te.create_schedule(C.op) s = te.create_schedule(C.op)
fapi = lower(s, [A, B, C])
def verify(target): def verify(target):
if not tvm.runtime.enabled(target): if not tvm.runtime.enabled(target):
print("Target %s is not enabled" % target) print("Target %s is not enabled" % target)
return return
f = tvm.target.codegen.build_module(fapi, target) f = tvm.driver.build(s, [A, B, C], target)
# verify # verify
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx)
......
...@@ -39,8 +39,9 @@ def test_dltensor_compatible(): ...@@ -39,8 +39,9 @@ def test_dltensor_compatible():
A[i + 1] = A[i] + 1 A[i + 1] = A[i] + 1
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) mod = tvm.testing.LoweredFuncsToIRModule([fapi])
f = tvm.target.codegen.build_module(fapi, "stackvm") mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
f = tvm.target.codegen.build_module(mod, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a) aview = MyTensorView(a)
f(aview) f(aview)
......
...@@ -58,8 +58,7 @@ def test_dso_module_load(): ...@@ -58,8 +58,7 @@ def test_dso_module_load():
tvm.tir.Load(dtype, Ab.data, i) + 1, tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1)) i + 1))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) m = tvm.driver.build(fapi, target="llvm")
m = tvm.target.codegen.build_module(fapi, "llvm")
for name in names: for name in names:
m.save(name) m.save(name)
......
...@@ -74,9 +74,8 @@ def test_add_pipeline(): ...@@ -74,9 +74,8 @@ def test_add_pipeline():
binds = {A : Ab} binds = {A : Ab}
# BUILD and invoke the kernel. # BUILD and invoke the kernel.
f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline") f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline")
fsplits = [x for x in tvm.tir.ir_pass.SplitHostDevice(f1)] mhost = tvm.build(f1, target="c")
fsplits[0] = tvm.tir.ir_pass.LowerTVMBuiltin(fsplits[0])
mhost = tvm.target.codegen.build_module(fsplits[0], "c")
temp = util.tempdir() temp = util.tempdir()
path_dso = temp.relpath("temp.so") path_dso = temp.relpath("temp.so")
mhost.export_library(path_dso) mhost.export_library(path_dso)
......
...@@ -63,79 +63,27 @@ def test_add_pipeline(): ...@@ -63,79 +63,27 @@ def test_add_pipeline():
s[D].bind(xi, te.thread_axis("threadIdx.x")) s[D].bind(xi, te.thread_axis("threadIdx.x"))
s[D].bind(xo, te.thread_axis("blockIdx.x")) s[D].bind(xo, te.thread_axis("blockIdx.x"))
# compile to IR
s = s.normalize()
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
Db = tvm.tir.decl_buffer(D.shape, D.dtype, name='D')
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
stmt = tvm.tir.ir_pass.Simplify(stmt)
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
fsplits = [x for x in tvm.tir.ir_pass.SplitHostDevice(fapi)]
# lower the floordiv(use stackvm rules so it works for all targets)
fsplits = [tvm.tir.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
fsplits[0] = tvm.tir.ir_pass.LowerTVMBuiltin(fsplits[0])
def check_target(device, host="stackvm"): def check_target(device, host="stackvm"):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
if not ctx.exist: if not ctx.exist:
return return
if not tvm.runtime.enabled(host): if not tvm.runtime.enabled(host):
return return
mhost = tvm.target.codegen.build_module(fsplits[0], host) mhost = tvm.driver.build(s, [A, B, D], target=device, target_host=host)
mdev = tvm.target.codegen.build_module(fsplits[1:], device)
mhost.import_module(mdev)
code = mdev.get_source()
f = mhost.entry_func
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, d)
tvm.testing.assert_allclose(
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
def check_module_save(device, host="stackvm"):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
if not tvm.runtime.enabled(host):
return
if device == "cuda":
fmt = "ptx"
elif device == "rocm":
fmt = "hsaco"
else:
fmt = device
mhost = tvm.target.codegen.build_module(fsplits[0], host)
mdev = tvm.target.codegen.build_module(fsplits[1:], device)
temp = util.tempdir()
mpath = temp.relpath("test.%s" % fmt)
mdev.save(mpath)
mdev2 = tvm.runtime.load_module(mpath)
mhost.import_module(mdev2)
f = mhost.entry_func f = mhost.entry_func
# launch the kernel. # launch the kernel.
n = 1027 n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx) b = tvm.nd.array(np.random.uniform(size=()).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx) d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d) f(a, b, d)
tvm.testing.assert_allclose( tvm.testing.assert_allclose(
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1) d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
check_target("cuda", host="stackvm")
check_target("cuda", host="llvm") check_target("cuda", host="llvm")
check_module_save("cuda", host="stackvm")
check_target("nvptx", host="llvm") check_target("nvptx", host="llvm")
check_target("vulkan", host="llvm") check_target("vulkan", host="llvm")
check_module_save("vulkan", host="stackvm")
check_target("rocm", host="llvm") check_target("rocm", host="llvm")
check_module_save("rocm", host="llvm")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -33,8 +33,7 @@ def test_static_callback(): ...@@ -33,8 +33,7 @@ def test_static_callback():
A[i] = A[i] + 1 A[i] = A[i] + 1
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) f = tvm.driver.build(fapi, target="llvm")
f = tvm.target.codegen.build_module(fapi, "llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
f(a) f(a)
...@@ -57,8 +56,7 @@ def test_static_init(): ...@@ -57,8 +56,7 @@ def test_static_init():
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi) f = tvm.driver.build(fapi, target="llvm")
f = tvm.target.codegen.build_module(fapi, "llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
......
...@@ -22,7 +22,7 @@ def run_jit(fapi, check): ...@@ -22,7 +22,7 @@ def run_jit(fapi, check):
for target in ["llvm", "stackvm"]: for target in ["llvm", "stackvm"]:
if not tvm.runtime.enabled(target): if not tvm.runtime.enabled(target):
continue continue
f = tvm.target.codegen.build_module(fapi, target) f = tvm.driver.build(fapi, target=target)
s = f.get_source() s = f.get_source()
check(f) check(f)
...@@ -37,8 +37,6 @@ def test_stack_vm_basic(): ...@@ -37,8 +37,6 @@ def test_stack_vm_basic():
Ab = tvm.tir.decl_buffer((n, ), "float32") Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0])) stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
fapi = tvm.tir.ir_pass.LowerIntrin(fapi, "stackvm")
run_jit(fapi, lambda f: f(a)) run_jit(fapi, lambda f: f(a))
...@@ -60,7 +58,6 @@ def test_stack_vm_loop(): ...@@ -60,7 +58,6 @@ def test_stack_vm_loop():
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f): def check(f):
f(a) f(a)
...@@ -83,7 +80,6 @@ def test_stack_vm_cond(): ...@@ -83,7 +80,6 @@ def test_stack_vm_cond():
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
def check(f): def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
...@@ -103,7 +99,6 @@ def test_vm_parallel(): ...@@ -103,7 +99,6 @@ def test_vm_parallel():
A[i] = A[i] + 1 A[i] = A[i] + 1
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
def check(f): def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
......
...@@ -17,34 +17,6 @@ ...@@ -17,34 +17,6 @@
import tvm import tvm
from tvm import te from tvm import te
def test_storage_sync():
m = te.size_var('m')
l = te.size_var('l')
A = te.placeholder((m, l), name='A')
A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = te.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A2].bind(xo, te.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
flist = tvm.tir.ir_pass.SplitHostDevice(f)
f = flist[1]
f = tvm.tir.ir_pass.ThreadSync(f, "shared")
body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")
def test_coproc_sync(): def test_coproc_sync():
@tvm.register_func("tvm.info.mem.global.cache") @tvm.register_func("tvm.info.mem.global.cache")
def meminfo_cache(): def meminfo_cache():
...@@ -133,6 +105,5 @@ def test_coproc_sync3(): ...@@ -133,6 +105,5 @@ def test_coproc_sync3():
if __name__ == "__main__": if __name__ == "__main__":
test_coproc_sync() test_coproc_sync()
test_storage_sync()
test_coproc_sync2() test_coproc_sync2()
test_coproc_sync3() test_coproc_sync3()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_thread_storage_sync():
m = te.size_var('m')
l = te.size_var('l')
A = te.placeholder((m, l), name='A')
A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = te.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A2].bind(xo, te.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
flist = tvm.tir.ir_pass.SplitHostDevice(f)
f = flist[1]
fname = f.name
mod = tvm.testing.LoweredFuncsToIRModule([f])
cuda_target = tvm.target.create("cuda")
mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")
if __name__ == "__main__":
test_thread_storage_sync()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment