Unverified Commit 75e936e1 by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager. (#5225)

* [REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager.

- SplitHostDevice
- ThreadSync
- BindDevice
- LowerThreadAllreduce
- Provide a temp fix for printing IRModule with PrimFunc before the formal text printer.

* Address comments, fix tests.

* Fix relay tests

* Explicit move
parent 9b274cbb
......@@ -47,19 +47,19 @@ enum class CallingConv : int {
*/
kDefault = 0,
/*!
* \brief PackedFunc that exposes a CPackedFunc signature.
*
* - Calling by PackedFunc calling convention.
* - Implementation: Expose a function with the CPackedFunc signature.
*/
kCPackedFunc = 1,
/*!
* \brief Device kernel launch
*
* - Call by PackedFunc calling convention.
* - Implementation: defined by device runtime (e.g. runtime/cuda)
*/
kDeviceKernelLaunch = 2,
};
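For context, the calling convention is stored on each function as the "calling_conv" attribute; a minimal Python sketch of checking it (mirroring the Filter conditions used later in this diff; `func` is assumed to be a tir.PrimFunc with attributes already bound):

from tvm.ir import CallingConv

def is_device_kernel(func):
    # Functions without the attribute use kDefault (0) on the C++ side.
    return ("calling_conv" in func.attrs and
            func.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH)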
/*!
......
......@@ -324,6 +324,8 @@ class IRModule : public ObjectRef {
/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
// allow copy on write.
TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
};
/*!
......
......@@ -49,6 +49,16 @@ struct ExprDeepEqual {
public:
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
/*!
* \brief Find undefined vars in the statement.
* \param stmt The statement to be checked.
* \param defs The vars that are already defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
......@@ -407,56 +407,6 @@ LoweredFunc MakeAPI(Stmt body,
bool is_restricted);
/*!
* \brief Bind the device type of host function to be device_type.
* \param func The function to be bound.
* \param device_type The device type to be bound.
* \return The bound function.
*/
LoweredFunc BindDeviceType(LoweredFunc func,
int device_type);
/*!
* \brief Find undefined vars in the statement.
* \param stmt The statement to be checked.
* \param defs The vars that are already defined.
* \return Array of undefined vars.
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
/*!
* \brief Split the function into a host function and device functions.
* \param func The function to be split.
*
* \return Array of functions, the first one is host function,
* the others are device functions.
*/
Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
/*!
* \brief Insert sync between parallel read/write of shared buffers.
*
* \param stmt The stmt to be transformed.
* \param storage_scope The storage scope considered.
*/
LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
/*!
* \brief Lower cross-thread allreduce in the stmt.
* \param f The device function to be lowered.
* \param warp_size The size of a warp, within which no sync is needed.
* \return Transformed function.
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
/*!
* \brief Lower warp memory in stmt.
* \param f The device function to be lowered.
* \param warp_size The size of a warp, within which no sync is needed.
* This pass only takes effect if warp_size is greater than one.
* \return Transformed function.
*/
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
/*!
* \brief Remap the thread axis
*
* This can be used to get equivalent program which uses
......@@ -471,26 +421,6 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
/*!
* \brief Lower packed function call.
* \param f The function to be lowered.
* \return Transformed function.
*/
LoweredFunc LowerTVMBuiltin(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
* to avoid pointer casting in backend when possible.
*
* \note implemented in storage_rewrite.cc
* \param f The function to be transformed
* \return Transformed function.
*/
LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
......@@ -514,14 +444,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f);
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
/*!
* \brief Infer the TensorCore fragment information using tensor intrinsics
*
* \param f The device function to be lowered.
* \return Transformed function.
*/
LoweredFunc InferFragment(LoweredFunc f);
/*!
* \brief Verify if memory accesses are legal for a specific target device type.
*
* In the case that tgt is cuda, if not all workload is bound with
......
......@@ -59,6 +59,21 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);
/*!
* \brief Bind the device type of the function to
* the device_type specified in the target attribute.
*
* \return The pass.
*/
TVM_DLL Pass BindDeviceType();
/*!
* \brief Split the function into a host function and device functions.
*
* \return The pass.
*/
TVM_DLL Pass SplitHostDevice();
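Together with ThreadSync and LowerThreadAllreduce, these passes replace the old LoweredFunc helpers. A condensed Python sketch of the mixed pipeline from build.py in this diff (`mod` and `target` are assumed to already exist):

import tvm

seq = tvm.ir.transform.Sequential([
    tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
    tvm.tir.transform.ThreadSync("shared"),
    tvm.tir.transform.ThreadSync("warp"),
    tvm.tir.transform.InferFragment(),
    tvm.tir.transform.LowerThreadAllreduce(),
    tvm.tir.transform.BindDeviceType(),
    tvm.tir.transform.SplitHostDevice(),
])
mod = seq(mod)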
/*!
* \brief Skip assert statements.
*
* \return The pass.
......
......@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""The build utils in python.
This module provides the functions to transform schedule to
......@@ -25,6 +27,7 @@ import tvm.tir
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.tir.stmt import LoweredFunc
......@@ -222,75 +225,59 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
@tvm.tir.transform.prim_func_pass(opt_level=0)
class BindTarget:
def __init__(self, target):
self.target = target
# pylint: disable=unused-argument
def transform_function(self, func, mod, ctx):
return func.with_attr("target", self.target)
target = _target.create(target)
target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type
fhost = []
fdevice = []
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
if func.func_type == LoweredFunc.MixedFunc:
if BuildConfig.current().detect_global_barrier:
func = ir_pass.ThreadSync(func, "global")
func = ir_pass.ThreadSync(func, "shared")
func = ir_pass.ThreadSync(func, "warp")
func = ir_pass.InferFragment(func)
warp_size = target.thread_warp_size
func = ir_pass.LowerThreadAllreduce(func, warp_size)
fsplits = list(ir_pass.SplitHostDevice(func))
fhost.append(fsplits[0])
for x in fsplits[1:]:
fdevice.append(x)
elif func.func_type == LoweredFunc.HostFunc:
fhost.append(func)
elif func.func_type == LoweredFunc.DeviceFunc:
fdevice.append(func)
else:
raise ValueError("unknown function type %d" % func.func_type)
if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
if BuildConfig.current().detect_global_barrier:
opt_mixed += [tvm.tir.transform.ThreadSync("global")]
opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
tvm.tir.transform.ThreadSync("warp"),
tvm.tir.transform.InferFragment(),
tvm.tir.transform.LowerThreadAllreduce(),
tvm.tir.transform.BindDeviceType(),
tvm.tir.transform.SplitHostDevice()]
mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert not fdevice
target_host = _target.create(target_host)
# device optimizations
opt_device = tvm.ir.transform.Sequential(
[tvm.tir.transform.Filter(
lambda f: "calling_conv" in f.attrs and
f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
mod_dev = opt_device(mod_mixed)
# host optimizations
opt_host = tvm.ir.transform.Sequential(
[tvm.tir.transform.Filter(
lambda f: "calling_conv" not in f.attrs or
f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.Apply(lambda f: f.with_attr("target", target_host)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()])
mod_host = opt_host(mod_mixed)
if device_type == ndarray.cpu(0).device_type and target_host == target:
assert len(mod_dev.functions) == 0
if "gpu" in target.keys and len(mod_dev.functions) == 0:
warnings.warn(
"Specified target %s, but cannot find device code, did you do "
"bind?" % target)
rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
return mod_host, rt_mod_dev
......
......@@ -23,7 +23,7 @@ from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .function import BaseFunc
from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs, DictAttrs, make_node
......
......@@ -15,10 +15,18 @@
# specific language governing permissions and limitations
# under the License.
"""Function defintiions."""
from enum import IntEnum
from .expr import RelayExpr
from . import _ffi_api
class CallingConv(IntEnum):
"""Possible kinds of calling conventions."""
DEFAULT = 0
C_PACKED_FUNC = 1
DEVICE_KERNEL_LAUNCH = 2
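These values mirror the C++ CallingConv enum above. A small sketch (with a hypothetical IRModule `mod`) of listing the kernels that SplitHostDevice tagged with DEVICE_KERNEL_LAUNCH:

kernels = [gvar for gvar, func in mod.functions.items()
           if "calling_conv" in func.attrs and
           func.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH]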
class BaseFunc(RelayExpr):
"""Base class of all functions."""
@property
......
......@@ -60,7 +60,6 @@ class IRModule(Node):
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)
def __setitem__(self, var, val):
"""Add a mapping to the module.
......
......@@ -16,6 +16,7 @@
# under the License.
"""TIR specific function pass support."""
import inspect
import types
import functools
import tvm._ffi
......@@ -142,7 +143,7 @@ def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.MakeFunctionPass(pass_arg, info)
return _ffi_api.CreatePrimFuncPass(pass_arg, info)
if pass_func:
return create_function_pass(pass_func)
......
......@@ -17,6 +17,70 @@
"""Wrapping existing transformations."""
# pylint: disable=invalid-name
from . import _ffi_api
from . import function_pass as _fpass
def Apply(ftransform):
"""Apply ftransform to each function in the Module.
This function is a thin wrapper around tvm.tir.transform.prim_func_pass
Parameters
----------
ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc
The transformation pass.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return ftransform(func)
return _fpass.prim_func_pass(_transform, opt_level=0)
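For instance, the build pipeline in this diff uses Apply to stamp a target attribute onto every PrimFunc (`mod` is an assumed IRModule):

import tvm
mod = Apply(lambda f: f.with_attr("target", tvm.target.create("cuda")))(mod)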
def Filter(fcond):
"""Filter functions by the calling convention attribute.
Parameters
----------
fcond : tvm.tir.PrimFunc -> bool
The condition of the filtering.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return func if fcond(func) else None
return _fpass.prim_func_pass(_transform, opt_level=0)
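Because returning None from a prim_func_pass drops the function (see the pass-manager change later in this diff), Filter can carve a mixed module into host and device halves. A sketch, assuming CallingConv is imported from tvm.ir and `mod_mixed` is an IRModule:

keep_host = Filter(
    lambda f: "calling_conv" not in f.attrs or
    f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH)
mod_host = keep_host(mod_mixed)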
def BindDeviceType():
"""Bind the device type of the function to be
the device_type specified in the target attribute.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.BindDeviceType()
def SplitHostDevice():
"""Split the function into a host function and device functions.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.SplitHostDevice()
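Usage sketch, adapted from the tests below. SplitHostDevice requires the target attribute to be bound first, and names extracted kernels "<global_symbol>_kernel<i>"; the symbol "main" here is hypothetical:

import tvm
mod = Apply(lambda f: f.with_attr("target", tvm.target.create("cuda")))(mod)
mod = SplitHostDevice()(mod)
fdevice = mod["main_kernel0"]  # assumes the host function's global symbol is "main"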
def SkipAssert():
......
......@@ -185,75 +185,50 @@ transform::Pass BindTarget(Target target) {
}
template<typename FCond>
transform::Pass FilterBy(FCond fcond) {
auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
if (fcond(f)) {
return f;
} else {
return tir::PrimFunc(nullptr);
}
};
return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
}
std::pair<IRModule, IRModule>
split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
std::unordered_set<std::string> all_names;
for (const auto& x : funcs) {
CHECK(all_names.count(x->name) == 0)
<< "Duplicate function name " << x->name;
all_names.insert(x->name);
}
Array<LoweredFunc> fhost;
Array<LoweredFunc> fdevice;
for (const auto& x : funcs) {
CHECK(tir::VerifyMemory(x, target->device_type))
<< "Direct host side access to device memory is detected in "
<< x->func_name() << ". Did you forget to bind?";
if (x->func_type == tir::kMixedFunc) {
auto func = x;
if (config->detect_global_barrier) {
func = tir::ThreadSync(func, "global");
}
func = tir::ThreadSync(func, "shared");
func = tir::ThreadSync(func, "warp");
func = tir::InferFragment(func);
func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
auto fsplits = tir::SplitHostDevice(func);
fhost.push_back(fsplits[0]);
for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
fdevice.push_back(*f);
}
} else if (x->func_type == tir::kHostFunc) {
fhost.push_back(x);
} else if (x->func_type == tir::kDeviceFunc) {
fdevice.push_back(x);
} else {
LOG(FATAL) << "unknown function type " << x->func_type;
}
}
auto keys = target->keys();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && fdevice.size() == 0) {
LOG(WARNING) << "Specified target "
<< target->str()
<< " but cannot find device code. Did you forget to bind?";
}
IRModule mod_mixed = codegen::ToIRModule(funcs);
if (target->device_type == target::llvm()->device_type &&
target_host == target) {
CHECK(fdevice.empty()) << "No device code should be generated when target "
<< "and host_target are both llvm target."
<< "\n";
}
for (size_t i = 0; i < fhost.size(); ++i) {
auto func = fhost[i];
func = tir::BindDeviceType(func, target->device_type);
fhost.Set(i, func);
}
Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
if (config->detect_global_barrier) {
mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
}
mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
mixed_pass_list.push_back(tir::transform::BindDeviceType());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
auto opt_mixed = transform::Sequential(mixed_pass_list);
mod_mixed = opt_mixed(std::move(mod_mixed));
// host pipeline
auto mhost = codegen::ToIRModule(fhost);
auto host_pass_list = {
FilterBy([](const tir::PrimFunc& f) {
int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
return value != static_cast<int>(CallingConv::kDeviceKernelLaunch);
}),
BindTarget(target_host),
tir::transform::LowerTVMBuiltin(),
tir::transform::LowerIntrin(),
......@@ -261,18 +236,38 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs,
tir::transform::CombineContextCall(),
};
auto opt_host = transform::Sequential(host_pass_list);
mhost = opt_host(mhost);
auto mhost = opt_host(mod_mixed);
// device pipeline
auto mdevice = codegen::ToIRModule(fdevice);
auto device_pass_list = {
FilterBy([](const tir::PrimFunc& f) {
int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
return value == static_cast<int>(CallingConv::kDeviceKernelLaunch);
}),
BindTarget(target),
tir::transform::LowerWarpMemory(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
};
auto opt_device = transform::Sequential(device_pass_list);
mdevice = opt_device(mdevice);
auto mdevice = opt_device(mod_mixed);
// some final misc checks.
auto keys = target->keys();
bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
if (target_is_gpu && mdevice->functions.size() == 0) {
LOG(WARNING) << "Specified target "
<< target->str()
<< " but cannot find device code. Did you forget to bind?";
}
if (target->device_type == target::llvm()->device_type &&
target_host == target) {
CHECK(mdevice->functions.empty())
<< "No device code should be generated when target "
<< "and host_target are both llvm target."
<< "\n";
}
return {mhost, mdevice};
}
......
......@@ -34,6 +34,7 @@
*/
#include <tvm/ir/type_functor.h>
#include <tvm/ir/module.h>
#include <tvm/tir/function.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
......@@ -434,6 +435,10 @@ class RelayTextPrinter :
Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
if (auto* n = base_func.as<relay::FunctionNode>()) {
return PrintFunc(prefix, GetRef<relay::Function>(n));
} else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
std::ostringstream os;
os << GetRef<tir::PrimFunc>(n);
return Doc::RawText(os.str());
} else {
// def @xyz = meta['ExternalFunc'][id]
Doc doc;
......@@ -455,8 +460,9 @@ class RelayTextPrinter :
}
// functions
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
if (kv.second.as<relay::FunctionNode>()) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
}
if (counter++ != 0) {
doc << Doc::NewLine();
}
......
......@@ -50,9 +50,10 @@ tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
Map<tir::Var, PrimExpr> remap_vars;
for (auto var : from->args) {
auto it = from->handle_data_type.find(var);
if (it != from->handle_data_type.end()) {
tir::Var new_var(var->name_hint,
PointerType(PrimType((*it).second->dtype)));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
......
......@@ -24,6 +24,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <memory>
#include <unordered_map>
#include "codegen_cpu.h"
......
......@@ -108,8 +108,13 @@ IRModule PrimFuncPassNode::operator()(const IRModule& mod,
updates.push_back({it.first, updated_func});
}
}
// automatic removal of None
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
if (pair.second.defined()) {
updated_mod->Add(pair.first, pair.second, true);
} else {
updated_mod->Remove(pair.first);
}
}
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
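On the Python side this means a prim_func_pass can delete functions by returning None; a hypothetical sketch:

import tvm

@tvm.tir.transform.prim_func_pass(opt_level=0)
def keep_only_named(func, mod, ctx):
    # Returning None removes the function from the module.
    return func if "global_symbol" in func.attrs else None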
......
......@@ -128,10 +128,7 @@ REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin);
REGISTER_PASS(ThreadSync);
REGISTER_PASS(MakeAPI);
REGISTER_PASS(BindDeviceType);
REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
......@@ -141,7 +138,6 @@ REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(VerifyMemory);
......@@ -150,7 +146,6 @@ REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
......@@ -218,69 +218,6 @@ LoweredFunc MakeAPI(Stmt body,
return f;
}
class DeviceTypeBinder: public StmtExprMutator {
public:
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_context_type) {
if (const VarNode* var = op->value.as<VarNode>()) {
var_ = var;
PrimExpr value = make_const(op->value.dtype(), device_type_);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmtNode::make(op->value == value, os.str(), body);
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const IfThenElseNode* op) final {
// eager simplify if guard.
Stmt res = StmtExprMutator::VisitStmt_(op);
op = res.as<IfThenElseNode>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
return EvaluateNode::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
}
return res;
}
PrimExpr VisitExpr_(const NENode* op) final {
// eager check NE for device check
PrimExpr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NENode>();
if (tir::ExprDeepEqual()(op->a, op->b)) {
return make_const(op->dtype, false);
}
return res;
}
PrimExpr VisitExpr_(const VarNode* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
return GetRef<PrimExpr>(op);
}
}
public:
const VarNode* var_{nullptr};
int device_type_;
};
LoweredFunc BindDeviceType(LoweredFunc f,
int device_type) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = DeviceTypeBinder(device_type)(n->body);
return LoweredFunc(n);
}
} // namespace tir
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file bind_device_type.cc
* \brief Bind the device type according to the target field.
*/
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/analysis.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
namespace tvm {
namespace tir {
class DeviceTypeBinder: public StmtExprMutator {
public:
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::device_context_type) {
if (const VarNode* var = op->value.as<VarNode>()) {
var_ = var;
PrimExpr value = make_const(op->value.dtype(), device_type_);
Stmt body = StmtExprMutator::VisitStmt_(op);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmtNode::make(op->value == value, os.str(), body);
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const IfThenElseNode* op) final {
// eager simplify if guard.
Stmt res = StmtExprMutator::VisitStmt_(op);
op = res.as<IfThenElseNode>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
return EvaluateNode::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
}
return res;
}
PrimExpr VisitExpr_(const NENode* op) final {
// eager check NE for device check
PrimExpr res = StmtExprMutator::VisitExpr_(op);
op = res.as<NENode>();
if (tir::ExprDeepEqual()(op->a, op->b)) {
return make_const(op->dtype, false);
}
return res;
}
PrimExpr VisitExpr_(const VarNode* op) final {
if (op == var_) {
return make_const(op->dtype, device_type_);
} else {
return GetRef<PrimExpr>(op);
}
}
public:
const VarNode* var_{nullptr};
int device_type_;
};
namespace transform {
Pass BindDeviceType() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "BindDeviceType: Require the target attribute";
n->body = DeviceTypeBinder(target->device_type)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {});
}
TVM_REGISTER_GLOBAL("tir.transform.BindDeviceType")
.set_body_typed(BindDeviceType);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -340,14 +340,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::unordered_map<const VarNode *, Stmt> alloc_remap_;
};
LoweredFunc
LowerThreadAllreduce(LoweredFunc f, int warp_size) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ThreadAllreduceBuilder(warp_size)(n->body);
return LoweredFunc(n);
}
namespace transform {
Pass LowerThreadAllreduce() {
......@@ -356,10 +348,6 @@ Pass LowerThreadAllreduce() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute";
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
CHECK(calling_conv.defined() &&
calling_conv->value == static_cast<int>(CallingConv::kDeviceKernelLaunch))
<< "LowerThreadAllreeduce: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body);
return f;
};
......
......@@ -21,18 +21,22 @@
* \file split_host_device.cc
* \brief Split device function from host.
*/
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <unordered_map>
namespace tvm {
namespace tir {
// use/def analysis, also delete unreferenced lets
class IRUseDefAnalysis : public StmtExprMutator {
class VarUseDefAnalysis : public StmtExprMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
......@@ -156,8 +160,27 @@ class IRUseDefAnalysis : public StmtExprMutator {
std::unordered_map<const VarNode*, int> def_count_;
};
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
VarUseDefAnalysis m;
for (Var arg : args) {
m.use_count_[arg.get()] = 0;
}
m(stmt);
return m.undefined_;
}
class HostDeviceSplitter : public StmtMutator {
public:
explicit HostDeviceSplitter(IRModuleNode* device_mod,
Target device_target,
std::string name_prefix)
: device_mod_(device_mod),
device_target_(device_target),
name_prefix_(name_prefix) {
}
Stmt VisitStmt_(const AllocateNode* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return StmtMutator::VisitStmt_(op);
......@@ -172,86 +195,128 @@ class HostDeviceSplitter : public StmtMutator {
return StmtMutator::VisitStmt_(op);
}
Array<LoweredFunc> Split(LoweredFunc f) {
CHECK_EQ(f->func_type, kMixedFunc);
for (auto kv : f->handle_data_type) {
handle_data_type_[kv.first.get()] = kv.second;
}
name_ = f->name;
ObjectPtr<LoweredFuncNode> n =
make_object<LoweredFuncNode>(*f.operator->());
n->body = operator()(f->body);
n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
ret.push_back(x);
}
return ret;
}
private:
Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os;
os << name_prefix_ << "_kernel" << device_func_counter_++;
std::string kernel_symbol = os.str();
// isolate the device function.
VarUseDefAnalysis m;
m.visit_thread_extent_ = false;
body = m(std::move(body));
Array<Var> params;
Array<PrimExpr> arguments;
Map<tir::Var, PrimExpr> remap_vars;
// Strictly order the arguments: Var pointers, positional arguments.
for (Var var : m.undefined_) {
if (var.dtype().is_handle()) {
// Create a new version of the handle var.
auto it = handle_data_type_.find(var.get());
if (it != handle_data_type_.end()) {
tir::Var new_var(var->name_hint,
PointerType(PrimType((*it).second->dtype)));
params.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
params.push_back(var);
}
arguments.push_back(var);
}
}
// positional arguments
for (Var var : m.undefined_) {
if (!var.dtype().is_handle()) {
params.push_back(var);
arguments.push_back(var);
}
}
PrimFunc device_func(params, Substitute(body, remap_vars));
device_func = WithAttr(std::move(device_func), tir::attr::kDeviceThreadAxis, m.thread_axis_);
device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
Integer(CallingConv::kDeviceKernelLaunch));
device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
runtime::String(kernel_symbol));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
device_mod_->Add(GlobalVar(kernel_symbol), device_func);
// generate calls to the device function
Array<PrimExpr> call_args;
call_args.push_back(StringImmNode::make(kernel_symbol));
for (PrimExpr arg : arguments) {
call_args.push_back(arg);
}
for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext);
}
return EvaluateNode::make(CallNode::make(
DataType::Int(32), intrinsic::tvm_call_packed,
call_args, CallNode::Intrinsic));
}
// target ir module
IRModuleNode* device_mod_;
// Device target
Target device_target_;
// function name hint
std::string name_prefix_;
// Number of device functions.
int device_func_counter_{0};
std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "SplitHostDevice: Require the target attribute";
auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute";
HostDeviceSplitter splitter(
device_mod, target, static_cast<std::string>(global_symbol));
auto* n = func.CopyOnWrite();
n->body = splitter(std::move(n->body));
// set the host target to None.
func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
return std::move(func);
}
Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
return HostDeviceSplitter().Split(func);
}
namespace transform {
Pass SplitHostDevice() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
for (const auto& kv : mptr->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
auto updated_func = SplitHostDevice(std::move(func), mptr);
updates.push_back({kv.first, updated_func});
}
}
for (const auto& pair : updates) {
mptr->Add(pair.first, pair.second, true);
}
return m;
};
return tvm::transform::CreateModulePass(
pass_func, 0, "tir.SplitHostDevice", {});
}
TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice")
.set_body_typed(SplitHostDevice);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -218,26 +218,19 @@ Stmt InferFragment(Stmt stmt) {
return stmt;
}
LoweredFunc InferFragment(LoweredFunc f) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = InferFragment(f->body);
return LoweredFunc(n);
}
namespace transform {
Pass InferFragement() {
Pass InferFragment() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = InferFragment(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InferFragement", {});
return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
}
TVM_REGISTER_GLOBAL("tir.transform.InferFragement")
.set_body_typed(InferFragement);
TVM_REGISTER_GLOBAL("tir.transform.InferFragment")
.set_body_typed(InferFragment);
} // namespace transform
} // namespace tir
......
......@@ -374,13 +374,6 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
return ThreadSyncInserter(sync_scope, planner.syncs_inserted_)(std::move(stmt));
}
LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
CHECK_NE(f->func_type, kHostFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = ThreadSync(f->body, storage_scope);
return LoweredFunc(n);
}
namespace transform {
Pass ThreadSync(std::string storage_scope) {
......
......@@ -28,7 +28,7 @@ def test_loop_dependent_allocate():
s[AA].compute_at(s[C], s[C].op.axis[0])
# this line should fail because IRUseDefAnalysis sees an allocate statement
# referencing an undefined variable
tvm.lower(s, [A,C])
tvm.lower(s, [A, C])
if __name__ == "__main__":
test_loop_dependent_allocate()
......@@ -41,7 +41,9 @@ def test_double_buffer():
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.ir_pass.ThreadSync(f, "shared")
mod = tvm.testing.LoweredFuncsToIRModule([f])
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
......
......@@ -93,7 +93,10 @@ def test_flatten_double_buffer():
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.ir_pass.ThreadSync(f, "shared")
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
mod = tvm.testing.LoweredFuncsToIRModule([f])
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
......
......@@ -33,16 +33,15 @@ def test_lower_warp_mem():
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
f = tvm.lower(s, [A, B])
fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
# temp adapter to convert loweredFunc to IRModule
# to test passes in the new style.
fname = fdevice.name
mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
f = tvm.lower(s, [A, B], name="f")
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)
......
......@@ -38,13 +38,13 @@ def test_thread_storage_sync():
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
flist = tvm.tir.ir_pass.SplitHostDevice(f)
f = flist[1]
fname = f.name
mod = tvm.testing.LoweredFuncsToIRModule([f])
cuda_target = tvm.target.create("cuda")
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
cuda_target = tvm.target.create("cuda")
mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")
......