Commit 44bffdb3 by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Migrate low-level pass functions to Pass Manager, (#5213)

- Migrate LowerTVMBultin
- Migrate inferFragment, LowerThreadAllreduce
- Migrate ThreadSync
- Refactor target::Build to directly take IRModule.
- Remove un-used legacy functions.
parent 88d2f34b
......@@ -57,20 +57,6 @@ TVM_DLL Array<tir::LoweredFunc> lower(
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<tir::LoweredFunc> > split_dev_host_funcs(
const Array<tir::LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
......
......@@ -25,6 +25,7 @@
#define TVM_TARGET_CODEGEN_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>
......@@ -41,15 +42,24 @@ using runtime::TVMArgs;
using runtime::TVMRetValue;
/*!
* \brief Temporary backward compatible function to convert a list
* of LoweredFunc to a IRModule of PrimfFuncs
* \param funcs The input lowered function.
* \return The IRModule.
*
* \note This function is only used for code refactor and will be
* removed once the refactor completes.
*/
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);
/*!
* \brief Build a module from array of lowered function.
* \param funcs The functions to be built.
* \param mod The Module to be built
* \param target The target to be built.
* \return The builded module.
*
* \note Calls global API function "_codegen_build_" + target
* \return The result runtime::Module.
*/
runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target);
runtime::Module Build(IRModule mod, const Target& target);
/*!
* \brief Pack imported device library to a C file.
* Compile the C file and link with the host library
......
......@@ -477,12 +477,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
*/
LoweredFunc LowerTVMBuiltin(LoweredFunc f);
/*!
* \brief Combine context function calls.
* \param f The host function to be lowered.
* \return Transformed function.
*/
LoweredFunc CombineContextCall(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
......@@ -496,7 +490,6 @@ LoweredFunc CombineContextCall(LoweredFunc f);
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
......@@ -510,23 +503,6 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
PrimFunc PointerValueTypeRewrite(PrimFunc f);
/*!
* \brief Lower attached storage access information on device.
* Do this pass after all storage access analysis finish.
*
* \param func The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc func);
/*!
* \brief Lower intrinsic function calls.
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target);
/*!
* \brief Lower custom datatypes.
*
* See tvm::datatypes::Registry for more information on adding custom datatypes.
......@@ -546,13 +522,6 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
LoweredFunc InferFragment(LoweredFunc f);
/*!
* \brief skip assert stmt generation
* \param f The function to be transformed.
* \return Transformed function.
*/
LoweredFunc SkipAssert(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
* In the case that tgt is cuda, if not all workload is bound with
......
......@@ -59,11 +59,40 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);
/*!
* \brief Combine context calls in the host function.
* \brief skip assert stmt.
*
* \return The pass.
*/
TVM_DLL Pass CombineContextCall();
TVM_DLL Pass SkipAssert();
/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param storage_scope The storage scope considered.
* \return The pass.
*/
TVM_DLL Pass ThreadSync(std::string storage_scope);
/*!
* \brief Lower cross thread alleduce.
*
* \return The pass.
*/
TVM_DLL Pass LowerThreadAllreduce();
/*!
* \brief Infer the TensorCore fragment infomation using tensor intrinsics
*
* \return The pass.
*/
TVM_DLL Pass InferFragment();
/*!
* \brief Lower builtin intrinsics.
* \return The pass.
*/
TVM_DLL Pass LowerTVMBuiltin();
/*!
* \brief Lower the target specific function intrinsics in each of the function.
......@@ -73,6 +102,12 @@ TVM_DLL Pass CombineContextCall();
TVM_DLL Pass LowerIntrin();
/*!
* \brief Lower warp memory access to low-level device related function calls.
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();
/*!
* \brief Lower attached storage access information on device.
*
* \note Run this pass after all storage access analysis finish.
......@@ -82,10 +117,11 @@ TVM_DLL Pass LowerIntrin();
TVM_DLL Pass LowerDeviceStorageAccessInfo();
/*!
* \brief Lower warp memory access to low-level device related function calls.
* \brief Combine context calls in the host function.
*
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();
TVM_DLL Pass CombineContextCall();
/*!
......
......@@ -222,6 +222,15 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target
# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)
target = _target.create(target)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
......@@ -250,30 +259,39 @@ def _build_for_device(flist, target, target_host):
else:
raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
fhost = [ir_pass.LowerTVMBuiltin(x) for x in fhost]
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice
target_host = _target.create(target_host)
fdevice = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fdevice]
fhost = [ir_pass.LowerDeviceStorageAccessInfo(x) for x in fhost]
fdevice = [ir_pass.LowerIntrin(x, target.target_name) for x in fdevice]
fhost = [ir_pass.LowerIntrin(x, target_host.target_name) for x in fhost]
fhost = [ir_pass.CombineContextCall(x) for x in fhost]
mdev = codegen.build_module(fdevice, str(target)) if fdevice else None
return fhost, mdev
# device optimizations
mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
opt_device = tvm.ir.transform.Sequential(
[BindTarget(target),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_dev)
# host optimizations
mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
opt_host = tvm.ir.transform.Sequential(
[BindTarget(target_host),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()])
mod_host = opt_host(mod_host)
rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
return mod_host, rt_mod_dev
def build(inputs,
......@@ -402,19 +420,19 @@ def build(inputs,
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
fhost_all = []
mod_host_all = tvm.IRModule({})
device_modules = []
for tar, flist in target_flist.items():
fhost, mdev = _build_for_device(flist, tar, target_host)
# Save the current lowered functions of the host and the device module.
fhost_all += fhost
mod_host, mdev = _build_for_device(flist, tar, target_host)
mod_host_all.update(mod_host)
device_modules.append(mdev)
# Generate a unified host module.
mhost = codegen.build_module(fhost_all, str(target_host))
rt_mod_host = codegen.build_module(mod_host_all, target_host)
# Import all modules.
for mdev in device_modules:
if mdev:
mhost.import_module(mdev)
return mhost
rt_mod_host.import_module(mdev)
return rt_mod_host
......@@ -17,15 +17,16 @@
# under the License.
"""Code generation related functions."""
from . import _ffi_api
from . import target as _tgt
def build_module(lowered_func, target):
"""Build lowered_func into Module.
def build_module(mod, target):
"""Build IRModule into Module.
Parameters
----------
lowered_func : LoweredFunc
The lowered function
mod : tvm.IRModule
The ir module.
target : str
The target module type.
......@@ -35,7 +36,8 @@ def build_module(lowered_func, target):
module : runtime.Module
The corressponding module.
"""
return _ffi_api.Build(lowered_func, target)
target = _tgt.create(target) if isinstance(target, str) else target
return _ffi_api.Build(mod, target)
def llvm_lookup_intrinsic_id(name):
......
......@@ -16,19 +16,78 @@
# under the License.
"""Wrapping existing transformations."""
# pylint: disable=invalid-name
from . import _ffi_api
def CombineContextCall():
"""Combine context calls in the host function.
def SkipAssert():
"""Skip assert stmt.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.CombineContextCall()
return _ffi_api.SkipAssert()
def ThreadSync(storage_scope):
""" Insert sync between parallel read/write of shared buffers.
Parameters
----------
storage_scope: str
The target storage scope.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.ThreadSync(storage_scope)
def LowerThreadAllreduce():
"""Lower cross thread alleduce.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerThreadAllreduce()
def InferFragment():
""" Infer the TensorCore fragment infomation using tensor intrinsics.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.InferFragment()
def LowerWarpMemory():
"""Lower warp memory access to low-level device related function calls.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerWarpMemory()
def LowerTVMBuiltin():
"""Lower tvm builtin intrinsics.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerTVMBuiltin()
def LowerIntrin():
......@@ -57,15 +116,15 @@ def LowerDeviceStorageAccessInfo():
return _ffi_api.LowerDeviceStorageAccessInfo()
def LowerWarpMemory():
"""Lower warp memory access to low-level device related function calls.
def CombineContextCall():
"""Combine context calls in the host function.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerWarpMemory()
return _ffi_api.CombineContextCall()
def NarrowDataType():
......
......@@ -24,6 +24,8 @@
#include <dmlc/thread_local.h>
#include <tvm/driver/driver_api.h>
#include <tvm/te/operation.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/target/codegen.h>
#include <tvm/runtime/registry.h>
......@@ -174,7 +176,17 @@ Array<LoweredFunc> lower(te::Schedule sch,
return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
}
Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "BindTarget", {});
}
std::pair<IRModule, IRModule>
split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
......@@ -217,13 +229,6 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
}
}
for (size_t i = 0; i < fdevice.size(); i++) {
auto warp_size = target->thread_warp_size;
auto func = fdevice[i];
func = tir::LowerWarpMemory(fdevice[i], warp_size);
fdevice.Set(i, func);
}
auto keys = target->keys();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && fdevice.size() == 0) {
......@@ -232,11 +237,6 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
<< " but cannot find device code. Did you forget to bind?";
}
for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = tir::LowerIntrin(func, target->target_name);
fdevice.Set(i, func);
}
if (target->device_type == target::llvm()->device_type &&
target_host == target) {
......@@ -245,40 +245,38 @@ Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
<< "\n";
}
for (size_t i = 0; i < fdevice.size(); ++i) {
auto func = fdevice[i];
func = tir::LowerDeviceStorageAccessInfo(func);
fdevice.Set(i, func);
}
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = tir::BindDeviceType(func, target->device_type);
func = tir::LowerDeviceStorageAccessInfo(func);
func = tir::LowerTVMBuiltin(func);
fhost.Set(i, func);
}
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = tir::LowerIntrin(func, target_host->target_name);
func = tir::LowerDeviceStorageAccessInfo(func);
func = tir::CombineContextCall(func);
fhost.Set(i, func);
}
return {fhost, fdevice};
// host pipeline
auto mhost = codegen::ToIRModule(fhost);
auto host_pass_list = {
BindTarget(target_host),
tir::transform::LowerTVMBuiltin(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
tir::transform::CombineContextCall(),
};
auto opt_host = transform::Sequential(host_pass_list);
mhost = opt_host(mhost);
// device pipeline
auto mdevice = codegen::ToIRModule(fdevice);
auto device_pass_list = {
BindTarget(target),
tir::transform::LowerWarpMemory(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
};
auto opt_device = transform::Sequential(device_pass_list);
mdevice = opt_device(mdevice);
return {mhost, mdevice};
}
// Create a module for a specific device (target). The lowered functions
// associated with the host is returned as well.
runtime::Module DeviceBuild(const Array<LoweredFunc>& fdevice,
const Target& target) {
if (!fdevice.empty()) {
return codegen::Build(fdevice, target->str());
} else {
return runtime::Module(nullptr);
}
}
// Build for heterogeneous execution.
runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
......@@ -301,20 +299,21 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
target_host_val = DefaultTargetHost(target_host_val);
}
IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>());
for (const auto& it : inputs) {
auto host_dev_funcs =
auto pair =
split_dev_host_funcs(it.second, it.first, target_host_val, config);
auto& fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];
// Get the module for a certain target.
runtime::Module mdev = DeviceBuild(fdevice, it.first);
for (const auto& it : fhost) {
fhost_all.push_back(it);
auto& mhost = pair.first;
auto& mdevice = pair.second;
mhost_all->Update(mhost);
if (mdevice->functions.size() != 0) {
device_modules.push_back(codegen::Build(mdevice, it.first));
}
device_modules.push_back(mdev);
}
runtime::Module mhost = codegen::Build(fhost_all, target_host_val->str());
runtime::Module mhost = codegen::Build(mhost_all, target_host_val);
// Import all modules
for (const auto& it : device_modules) {
if (it.operator->()) {
......
......@@ -26,6 +26,7 @@
#include <tvm/ir/module.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/function.h>
#include <tvm/runtime/container.h>
......@@ -42,18 +43,6 @@
namespace tvm {
namespace codegen {
// The new build function.
// adapt the old function to the new one
runtime::Module BuildForIRModule(const IRModule& module,
const Target& target) {
std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "target.build." << target << " is not enabled";
return (*bf)(module, target->str());
}
// convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations.
......@@ -97,24 +86,16 @@ IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
return IRModule(functions);
}
runtime::Module Build(const Array<tir::LoweredFunc>& funcs,
const std::string& target) {
std::string mode = target;
size_t pos = mode.find(' ');
if (pos != std::string::npos) {
mode = mode.substr(0, pos);
}
Array<tir::LoweredFunc> transformed_funcs;
runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) {
for (const auto& x : funcs) {
auto func = tir::SkipAssert(x);
transformed_funcs.push_back(func);
}
mod = tir::transform::SkipAssert()(mod);
}
return BuildForIRModule(
transformed_funcs.size() != 0 ? ToIRModule(transformed_funcs) : ToIRModule(funcs),
Target::Create(target));
std::string build_f_name = "target.build." + target->target_name;
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "target.build." << target << " is not enabled";
return (*bf)(mod, target->str());
}
/*! \brief Helper class to serialize module */
......@@ -300,13 +281,7 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
}
TVM_REGISTER_GLOBAL("target.Build")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<tir::LoweredFunc>()) {
*ret = Build({args[0]}, args[1]);
} else {
*ret = Build(args[0], args[1]);
}
});
.set_body_typed(Build);
TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
.set_body_typed(ToIRModule);
......
......@@ -135,7 +135,6 @@ REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(LowerDeviceStorageAccessInfo)
REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer);
......@@ -143,12 +142,8 @@ REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(LowerWarpMemory);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerIntrin);
REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(LowerTVMBuiltin);
REGISTER_PASS(CombineContextCall);
REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
......
......@@ -109,18 +109,13 @@ class ContextCallCombiner final : public StmtExprMutator {
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> ctx_map_;
};
LoweredFunc CombineContextCall(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ContextCallCombiner().Combine(n->body);
return LoweredFunc(n);
}
namespace transform {
Pass CombineContextCall() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = ContextCallCombiner().Combine(n->body);
n->body = ContextCallCombiner().Combine(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
......
......@@ -142,11 +142,6 @@ Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower()(std::move(stmt));
}
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
return LoweredFunc(n);
}
namespace transform {
......
......@@ -283,16 +283,6 @@ Stmt LowerIntrinStmt(Stmt stmt, const std::string target_name) {
return IntrinInjecter(&analyzer, target_name)(std::move(stmt));
}
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
std::istringstream is(target);
std::string target_name;
is >> target_name;
n->body = LowerIntrinStmt(n->body, target_name);
return LoweredFunc(n);
}
namespace transform {
Pass LowerIntrin() {
......
......@@ -23,9 +23,14 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
#include <unordered_set>
#include "ir_util.h"
#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
......@@ -342,5 +347,28 @@ LowerThreadAllreduce(LoweredFunc f, int warp_size) {
n->body = ThreadAllreduceBuilder(warp_size)(n->body);
return LoweredFunc(n);
}
namespace transform {
Pass LowerThreadAllreduce() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute";
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "LowerThreadAllreeduce: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerThreadAllreduce")
.set_body_typed(LowerThreadAllreduce);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -18,14 +18,18 @@
*/
/*!
* Lower TVM related buildin intrinsics such as packed call.
* \file lower_tvm_buildin.cc
* Lower TVM related builtin intrinsics such as packed call.
* \file tir/transforms/lower_tvm_buildin.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <unordered_set>
#include "ir_util.h"
#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
namespace tvm {
......@@ -368,11 +372,20 @@ class BuiltinLower : public StmtExprMutator {
uint64_t max_arg_stack_{0};
};
LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
namespace transform {
Pass LowerTVMBuiltin() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = BuiltinLower().Build(n->body);
return LoweredFunc(n);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerTVMBuiltin")
.set_body_typed(LowerTVMBuiltin);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -385,14 +385,6 @@ class WarpMemoryRewriter : private StmtMutator {
std::unordered_map<const VarNode*, Range> var_dom_;
};
LoweredFunc
LowerWarpMemory(LoweredFunc f, int warp_size) {
CHECK_EQ(f->func_type, kDeviceFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body);
return LoweredFunc(n);
}
namespace transform {
Pass LowerWarpMemory() {
......@@ -401,7 +393,7 @@ Pass LowerWarpMemory() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute";
n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(n->body);
n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
......
......@@ -19,7 +19,9 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace tir {
......@@ -37,11 +39,21 @@ Stmt SkipAssert(Stmt stmt) {
return AssertSkipper()(std::move(stmt));
}
LoweredFunc SkipAssert(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = SkipAssert(f->body);
return LoweredFunc(n);
namespace transform {
Pass SkipAssert() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = AssertSkipper()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {});
}
TVM_REGISTER_GLOBAL("tir.transform.SkipAssert")
.set_body_typed(SkipAssert);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -23,11 +23,15 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
#include "storage_access.h"
#include "../pass/storage_access.h"
#include "../pass/ir_util.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -221,5 +225,20 @@ LoweredFunc InferFragment(LoweredFunc f) {
return LoweredFunc(n);
}
namespace transform {
Pass InferFragement() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = InferFragment(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InferFragement", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InferFragement")
.set_body_typed(InferFragement);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -18,16 +18,21 @@
*/
/*!
* \file storage_sync.cc
* \file thread_storage_sync.cc
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
#include <unordered_set>
#include "ir_util.h"
#include "storage_access.h"
#include "../pass/ir_util.h"
#include "../pass/storage_access.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -376,5 +381,20 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
return LoweredFunc(n);
}
namespace transform {
Pass ThreadSync(std::string storage_scope) {
auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = ThreadSync(std::move(n->body), storage_scope);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.ThreadSync", {});
}
TVM_REGISTER_GLOBAL("tir.transform.ThreadSync")
.set_body_typed(ThreadSync);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -50,13 +50,12 @@ def test_dot():
k = te.reduce_axis((0, n), 'k')
C = te.compute((1,), lambda _: te.sum(A[k] * B[k], axis=k), name='C')
s = te.create_schedule(C.op)
fapi = lower(s, [A, B, C])
def verify(target):
if not tvm.runtime.enabled(target):
print("Target %s is not enabled" % target)
return
f = tvm.target.codegen.build_module(fapi, target)
f = tvm.driver.build(s, [A, B, C], target)
# verify
ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.uniform(size=(nn,)).astype(A.dtype), ctx)
......
......@@ -39,8 +39,9 @@ def test_dltensor_compatible():
A[i + 1] = A[i] + 1
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.target.codegen.build_module(fapi, "stackvm")
mod = tvm.testing.LoweredFuncsToIRModule([fapi])
mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
f = tvm.target.codegen.build_module(mod, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
aview = MyTensorView(a)
f(aview)
......
......@@ -58,8 +58,7 @@ def test_dso_module_load():
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
m = tvm.target.codegen.build_module(fapi, "llvm")
m = tvm.driver.build(fapi, target="llvm")
for name in names:
m.save(name)
......
......@@ -74,9 +74,8 @@ def test_add_pipeline():
binds = {A : Ab}
# BUILD and invoke the kernel.
f1 = tvm.lower(s, [A,B,C], name="fadd_pipeline")
fsplits = [x for x in tvm.tir.ir_pass.SplitHostDevice(f1)]
fsplits[0] = tvm.tir.ir_pass.LowerTVMBuiltin(fsplits[0])
mhost = tvm.target.codegen.build_module(fsplits[0], "c")
mhost = tvm.build(f1, target="c")
temp = util.tempdir()
path_dso = temp.relpath("temp.so")
mhost.export_library(path_dso)
......
......@@ -63,79 +63,27 @@ def test_add_pipeline():
s[D].bind(xi, te.thread_axis("threadIdx.x"))
s[D].bind(xo, te.thread_axis("blockIdx.x"))
# compile to IR
s = s.normalize()
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
Db = tvm.tir.decl_buffer(D.shape, D.dtype, name='D')
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, D:Db}, 64)
stmt = tvm.tir.ir_pass.Simplify(stmt)
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Db], 0, True)
fsplits = [x for x in tvm.tir.ir_pass.SplitHostDevice(fapi)]
# lower the floordiv(use stackvm rules so it works for all targets)
fsplits = [tvm.tir.ir_pass.LowerIntrin(x, "stackvm") for x in fsplits]
fsplits[0] = tvm.tir.ir_pass.LowerTVMBuiltin(fsplits[0])
def check_target(device, host="stackvm"):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
if not tvm.runtime.enabled(host):
return
mhost = tvm.target.codegen.build_module(fsplits[0], host)
mdev = tvm.target.codegen.build_module(fsplits[1:], device)
mhost.import_module(mdev)
code = mdev.get_source()
f = mhost.entry_func
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
f(a, b, d)
tvm.testing.assert_allclose(
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
def check_module_save(device, host="stackvm"):
ctx = tvm.context(device, 0)
if not ctx.exist:
return
if not tvm.runtime.enabled(host):
return
if device == "cuda":
fmt = "ptx"
elif device == "rocm":
fmt = "hsaco"
else:
fmt = device
mhost = tvm.target.codegen.build_module(fsplits[0], host)
mdev = tvm.target.codegen.build_module(fsplits[1:], device)
temp = util.tempdir()
mpath = temp.relpath("test.%s" % fmt)
mdev.save(mpath)
mdev2 = tvm.runtime.load_module(mpath)
mhost.import_module(mdev2)
mhost = tvm.driver.build(s, [A, B, D], target=device, target_host=host)
f = mhost.entry_func
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=()).astype(Bb.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=Db.dtype), ctx)
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=()).astype(B.dtype), ctx)
d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
f(a, b, d)
tvm.testing.assert_allclose(
d.asnumpy(), a.asnumpy() + b.asnumpy() + 1)
check_target("cuda", host="stackvm")
check_target("cuda", host="llvm")
check_module_save("cuda", host="stackvm")
check_target("nvptx", host="llvm")
check_target("vulkan", host="llvm")
check_module_save("vulkan", host="stackvm")
check_target("rocm", host="llvm")
check_module_save("rocm", host="llvm")
if __name__ == "__main__":
......
......@@ -33,8 +33,7 @@ def test_static_callback():
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.target.codegen.build_module(fapi, "llvm")
f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
f(a)
......@@ -57,8 +56,7 @@ def test_static_init():
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
f = tvm.target.codegen.build_module(fapi, "llvm")
f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
......@@ -22,7 +22,7 @@ def run_jit(fapi, check):
for target in ["llvm", "stackvm"]:
if not tvm.runtime.enabled(target):
continue
f = tvm.target.codegen.build_module(fapi, target)
f = tvm.driver.build(fapi, target=target)
s = f.get_source()
check(f)
......@@ -37,8 +37,6 @@ def test_stack_vm_basic():
Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
fapi = tvm.tir.ir_pass.LowerIntrin(fapi, "stackvm")
run_jit(fapi, lambda f: f(a))
......@@ -60,7 +58,6 @@ def test_stack_vm_loop():
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
......@@ -83,7 +80,6 @@ def test_stack_vm_cond():
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......@@ -103,7 +99,6 @@ def test_vm_parallel():
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
......@@ -17,34 +17,6 @@
import tvm
from tvm import te
def test_storage_sync():
m = te.size_var('m')
l = te.size_var('l')
A = te.placeholder((m, l), name='A')
A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = te.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A2].bind(xo, te.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
flist = tvm.tir.ir_pass.SplitHostDevice(f)
f = flist[1]
f = tvm.tir.ir_pass.ThreadSync(f, "shared")
body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")
def test_coproc_sync():
@tvm.register_func("tvm.info.mem.global.cache")
def meminfo_cache():
......@@ -133,6 +105,5 @@ def test_coproc_sync3():
if __name__ == "__main__":
test_coproc_sync()
test_storage_sync()
test_coproc_sync2()
test_coproc_sync3()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_thread_storage_sync():
m = te.size_var('m')
l = te.size_var('l')
A = te.placeholder((m, l), name='A')
A1 = te.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = te.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = te.create_schedule(A2.op)
xo, xi = s[A2].split(A2.op.axis[0], factor=8)
s[A2].bind(xo, te.thread_axis("blockIdx.x"))
s[A1].compute_at(s[A2], xo)
s[A1].set_scope("shared")
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
flist = tvm.tir.ir_pass.SplitHostDevice(f)
f = flist[1]
fname = f.name
mod = tvm.testing.LoweredFuncsToIRModule([f])
cuda_target = tvm.target.create("cuda")
mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")
if __name__ == "__main__":
test_thread_storage_sync()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment