Unverified Commit 2b6d69c6 by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Migrate Low-level Passes to Pass Manager (#5198)

* [TIR][TRANSFORM] Migrate LowerIntrin

* LowerDeviceStorageAccessInfo

* Migrate LowerWarpMemory
parent 03ff0cd0
@@ -321,6 +321,9 @@ class IRModule : public ObjectRef {
* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);
/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
};
/*!
......
@@ -59,11 +59,33 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);
/*!
* \brief Create PrimFuncPass to combine context calls in the host function.
* \brief Combine context calls in the host function.
*
* \return The pass.
*/
Pass CombineContextCall();
TVM_DLL Pass CombineContextCall();
/*!
* \brief Lower target-specific intrinsic calls in each function.
*
* \return The pass.
*/
TVM_DLL Pass LowerIntrin();
/*!
* \brief Lower attached storage access information on device.
*
* \note Run this pass after all storage access analyses finish.
*
* \return The pass.
*/
TVM_DLL Pass LowerDeviceStorageAccessInfo();
/*!
* \brief Lower warp memory access to low-level, device-related function calls.
* \return The pass.
*/
TVM_DLL Pass LowerWarpMemory();
} // namespace transform
} // namespace tir
......
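These declarations expose each lowering step as a standalone Pass object rather than a free function over LoweredFunc, so the steps can be chained through the pass infrastructure. A minimal sketch of composing them from Python (a sketch, assuming `mod` is an IRModule whose PrimFuncs already carry the "target" attribute that LowerIntrin and LowerWarpMemory read, and that tvm.ir.transform.Sequential of this era is available):

    import tvm

    # Chain the newly migrated passes; Sequential applies them in order.
    # `mod` is assumed to be an IRModule with target-annotated PrimFuncs.
    seq = tvm.ir.transform.Sequential([
        tvm.tir.transform.LowerWarpMemory(),
        tvm.tir.transform.LowerIntrin(),
        tvm.tir.transform.LowerDeviceStorageAccessInfo(),
    ])
    mod = seq(mod)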
@@ -29,3 +29,40 @@ def CombineContextCall():
        The result pass
    """
    return _ffi_api.CombineContextCall()


def LowerIntrin():
    """Lower target-specific intrinsic calls.

    Returns
    -------
    fpass : tvm.ir.transform.Pass
        The result pass
    """
    return _ffi_api.LowerIntrin()


def LowerDeviceStorageAccessInfo():
    """Lower attached storage access information on device.

    Returns
    -------
    fpass : tvm.ir.transform.Pass
        The result pass

    Note
    ----
    Run this pass after all storage access analyses finish.
    """
    return _ffi_api.LowerDeviceStorageAccessInfo()


def LowerWarpMemory():
    """Lower warp memory access to low-level, device-related function calls.

    Returns
    -------
    fpass : tvm.ir.transform.Pass
        The result pass
    """
    return _ffi_api.LowerWarpMemory()
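Each wrapper is a thin call into the corresponding registered C++ pass. A minimal usage sketch, mirroring the tests later in this commit (LowerDeviceStorageAccessInfo is used here because it does not require a target attribute):

    import tvm
    from tvm import te

    # Build a trivial PrimFunc, wrap it into an IRModule, and run the pass.
    n = te.var("n")
    func = tvm.tir.PrimFunc([n], tvm.tir.Evaluate(tvm.tir.const(0, "int32")))
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.LowerDeviceStorageAccessInfo()(mod)
    print(mod["main"].body)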
@@ -408,8 +408,8 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
bool update = args[3];
CHECK(val->IsInstance<RelayExprNode>());
if (val->IsInstance<relay::FunctionNode>()) {
mod->Add(var, Downcast<relay::Function>(val), update);
if (val->IsInstance<BaseFuncNode>()) {
mod->Add(var, Downcast<BaseFunc>(val), update);
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = IRModule(make_object<IRModuleNode>(*mod.operator->()));
......
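Relaxing the check from relay::FunctionNode to BaseFuncNode is what allows an IRModule to hold tir::PrimFuncs alongside Relay functions. A sketch from the Python side (the dict-style IRModule constructor is an assumption based on the module API of this era):

    import tvm
    from tvm import te

    x = te.var("x")
    prim = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x))
    # Before this change, ir.Module_Add only downcast to relay.Function;
    # now any BaseFunc subclass, including tir.PrimFunc, is accepted.
    mod = tvm.IRModule({"main": prim})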
@@ -382,6 +382,7 @@ TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
IRModule mod = args[1];
ObjectRef ref = args[1];
*ret = pass(mod);
});
......
@@ -54,8 +54,6 @@ runtime::Module BuildForIRModule(const IRModule& module,
return (*bf)(module, target->str());
}
// convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations.
......
@@ -235,116 +235,5 @@ StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
return it->second;
}
class StorageAccessInfoLower : public StmtExprMutator {
public:
Stmt VisitStmt_(const AllocateNode* op) final {
// Lower allocate to device allocate when needed.
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
// For special memory, remove allocate, or use head expr
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.info.defined()) {
const MemoryInfo& info = it->second.info;
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
return AllocateNode::make(
op->buffer_var, op->dtype, op->extents, op->condition,
op->body, info->head_address, "nop");
}
return op->body;
} else {
return stmt;
}
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
const VarNode* buf = op->node.as<VarNode>();
StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return StmtExprMutator::VisitStmt_(op);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
private:
// tvm_access_ptr
PrimExpr MakeAccessPtr(const CallNode* op) {
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
Var buffer_var = Downcast<Var>(op->args[1]);
PrimExpr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.info.defined()) {
return MakeTaggedAccessPtr(
op->dtype, buffer_var, dtype, offset,
it->second.info);
}
CHECK(op->dtype.is_handle());
// Change to address_of
return AddressOffset(buffer_var, dtype, offset);
}
PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
Var buffer_var,
DataType dtype,
PrimExpr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
CHECK(info->head_address.defined())
<< buffer_var << " is not addressable.";
return AddressOffset(buffer_var, dtype, offset);
}
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(info->unit_bits % dtype_bits, 0);
return cast(ptr_type,
tir::Simplify(offset / make_const(
offset.dtype(), info->unit_bits / dtype_bits)));
}
// The storage entry.
struct StorageEntry {
// Whether it is tagged memory.
StorageScope scope;
// The memory info if any.
MemoryInfo info;
// Allocation counter
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const VarNode*, StorageEntry> storage_info_;
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower()(std::move(stmt));
}
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
return LoweredFunc(n);
}
} // namespace tir
} // namespace tvm
@@ -126,7 +126,7 @@ Pass CombineContextCall() {
n->body = ContextCallCombiner().Combine(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {});
return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
}
TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_device_storage_access.cc
* \brief Lower the special device storage access.
*/
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/buffer.h>
#include <tvm/target/target_info.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
#include "../pass/ir_util.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace tir {
using runtime::StorageScope;
using runtime::StorageRank;
class StorageAccessInfoLower : public StmtExprMutator {
public:
Stmt VisitStmt_(const AllocateNode* op) final {
// Lower allocate to device allocate when needed.
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
// For special memory, remove allocate, or use head expr
auto it = storage_info_.find(op->buffer_var.get());
if (it != storage_info_.end() && it->second.info.defined()) {
const MemoryInfo& info = it->second.info;
++it->second.alloc_count;
CHECK_LE(it->second.alloc_count, 1)
<< "Double allocation of " << it->second.scope.to_string();
if (info->head_address.defined()) {
return AllocateNode::make(
op->buffer_var, op->dtype, op->extents, op->condition,
op->body, info->head_address, "nop");
}
return op->body;
} else {
return stmt;
}
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::storage_scope) {
const VarNode* buf = op->node.as<VarNode>();
StorageScope scope = StorageScope::make(op->value.as<StringImmNode>()->value);
StorageEntry e;
e.scope = scope;
if (scope.tag.length() != 0) {
e.info = GetMemoryInfo(op->value.as<StringImmNode>()->value);
CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
}
storage_info_[buf] = e;
return StmtExprMutator::VisitStmt_(op);
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
return MakeAccessPtr(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
private:
// tvm_access_ptr
PrimExpr MakeAccessPtr(const CallNode* op) {
// Specially handle the buffer packed intrinsic
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
op = expr.as<CallNode>();
CHECK_EQ(op->args.size(), 5U);
DataType dtype = op->args[0].dtype();
const VarNode* buffer = op->args[1].as<VarNode>();
Var buffer_var = Downcast<Var>(op->args[1]);
PrimExpr offset = op->args[2];
auto it = storage_info_.find(buffer);
if (it != storage_info_.end() && it->second.info.defined()) {
return MakeTaggedAccessPtr(
op->dtype, buffer_var, dtype, offset,
it->second.info);
}
CHECK(op->dtype.is_handle());
// Change to address_of
return AddressOffset(buffer_var, dtype, offset);
}
PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
Var buffer_var,
DataType dtype,
PrimExpr offset,
const MemoryInfo& info) {
if (ptr_type.is_handle()) {
CHECK(info->head_address.defined())
<< buffer_var << " is not addressable.";
return AddressOffset(buffer_var, dtype, offset);
}
int dtype_bits = dtype.bits() * dtype.lanes();
CHECK_EQ(info->unit_bits % dtype_bits, 0);
return cast(ptr_type,
tir::Simplify(offset / make_const(
offset.dtype(), info->unit_bits / dtype_bits)));
}
// The storage entry.
struct StorageEntry {
// Whether it is tagged memory.
StorageScope scope;
// The memory info if any.
MemoryInfo info;
// Allocation counter
int alloc_count{0};
};
// The storage scope of each buffer
std::unordered_map<const VarNode*, StorageEntry> storage_info_;
};
Stmt LowerStorageAccessInfo(Stmt stmt) {
return StorageAccessInfoLower()(std::move(stmt));
}
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
return LoweredFunc(n);
}
namespace transform {
Pass LowerDeviceStorageAccessInfo() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = StorageAccessInfoLower()(std::move(n->body));
return f;
};
return CreatePrimFuncPass(
pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerDeviceStorageAccessInfo")
.set_body_typed(LowerDeviceStorageAccessInfo);
} // namespace transform
} // namespace tir
} // namespace tvm
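TVM_REGISTER_GLOBAL makes the pass constructor reachable from every frontend; the Python stub shown earlier is a thin wrapper over this global. A quick sketch for confirming the registration (assuming tvm.get_global_func, the standard packed-function lookup):

    import tvm

    # Look up the packed function registered above and construct the pass.
    f = tvm.get_global_func("tir.transform.LowerDeviceStorageAccessInfo")
    lower_pass = f()  # returns the tir.LowerDeviceStorageAccessInfo Pass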
@@ -23,11 +23,12 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <tvm/target/target.h>
#include <unordered_set>
#include "ir_util.h"
#include "../../arith/pattern_match.h"
#include "../../arith/ir_mutator_with_analyzer.h"
@@ -39,15 +40,12 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt_;
using IRMutatorWithAnalyzer::VisitExpr_;
IntrinInjecter(arith::Analyzer* analyzer, std::string target)
IntrinInjecter(arith::Analyzer* analyzer, std::string target_name)
: IRMutatorWithAnalyzer(analyzer) {
std::istringstream is(target);
std::string starget;
is >> starget;
patterns_.push_back("tvm.intrin.rule." + starget + ".");
patterns_.push_back("tvm.intrin.rule." + target_name + ".");
patterns_.push_back("tvm.intrin.rule.default.");
fma_ = runtime::Registry::Get(patterns_[0] + "fma");
if (target == "stackvm") {
if (target_name == "stackvm") {
support_bitwise_op_ = false;
}
}
@@ -280,21 +278,41 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
bool support_bitwise_op_{true};
};
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target) {
Stmt LowerIntrinStmt(Stmt stmt, const std::string& target_name) {
arith::Analyzer analyzer;
return IntrinInjecter(&analyzer, target)(std::move(stmt));
return IntrinInjecter(&analyzer, target_name)(std::move(stmt));
}
LoweredFunc
LowerIntrin(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerIntrinStmt(n->body, target);
std::istringstream is(target);
std::string target_name;
is >> target_name;
n->body = LowerIntrinStmt(n->body, target_name);
return LoweredFunc(n);
}
// Register the api only for test purposes
TVM_REGISTER_GLOBAL("ir_pass._LowerIntrinStmt")
.set_body_typed(LowerIntrinStmt);
namespace transform {
Pass LowerIntrin() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerIntrin: Require the target attribute";
arith::Analyzer analyzer;
n->body =
IntrinInjecter(&analyzer, target->target_name)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerIntrin")
.set_body_typed(LowerIntrin);
} // namespace transform
} // namespace tir
} // namespace tvm
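Unlike the legacy LowerIntrin(LoweredFunc, target) entry point, the pass now reads the target from the function itself, and the CHECK above fires when it is missing. A sketch of that contract, following the with_attr pattern used in the tests below:

    import tvm

    func = tvm.tir.PrimFunc([], tvm.tir.Evaluate(tvm.tir.const(0, "int32")))
    # Running tir.transform.LowerIntrin on a module built from `func` as-is
    # would abort with "LowerIntrin: Require the target attribute".
    func = func.with_attr("target", tvm.target.create("llvm"))
    mod = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))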
@@ -30,9 +30,14 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h>
#include <unordered_set>
#include "ir_util.h"
#include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"
@@ -388,5 +393,24 @@ LowerWarpMemory(LoweredFunc f, int warp_size) {
return LoweredFunc(n);
}
namespace transform {
Pass LowerWarpMemory() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute";
n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory")
.set_body_typed(LowerWarpMemory);
} // namespace transform
} // namespace tir
} // namespace tvm
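The warp size is no longer a caller-supplied parameter: the pass looks up target->thread_warp_size from the function's target attribute, where the legacy LowerWarpMemory(f, warp_size) form took it explicitly. A short sketch of where that value comes from:

    import tvm

    cuda_target = tvm.target.create("cuda")
    # The pass reads this field instead of taking warp_size as an argument.
    assert cuda_target.thread_warp_size == 32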
@@ -39,7 +39,7 @@ def test_for():
f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
# temp adapter to convert loweredFunc to IRModule
# to test passes in the new style.
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.CombineContextCall()(mod)
......
@@ -18,12 +18,15 @@ import tvm
from tvm import te
import numpy as np
def lower_intrin(stmt):
def lower_intrin(params, stmt):
"""wrapper to call transformation in stmt"""
lower_expr = isinstance(stmt, tvm.tir.PrimExpr)
stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass._LowerIntrinStmt(stmt, "llvm")
func = tvm.tir.PrimFunc(params, stmt).with_attr(
"target", tvm.target.create("llvm"))
func = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))["main"]
stmt = func.body
return stmt.value if lower_expr else stmt.body
@@ -70,19 +73,19 @@ def test_lower_floordiv():
y = te.var("y", dtype=dtype)
zero = tvm.tir.const(0, dtype)
# no constraints
res = lower_intrin(tvm.te.floordiv(x, y))
res = lower_intrin([x, y], tvm.te.floordiv(x, y))
check_value(res, x, y, data, lambda a, b: a // b)
# rhs >= 0
res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
# involves max
res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero))
check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
# lhs >= 0
res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero))
check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
# const power of two
res = lower_intrin(tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
@@ -93,16 +96,16 @@ def test_lower_floormod():
y = te.var("y", dtype=dtype)
zero = tvm.tir.const(0, dtype)
# no constraints
res = lower_intrin(tvm.te.floormod(x, y))
res = lower_intrin([x, y], tvm.te.floormod(x, y))
check_value(res, x, y, data, lambda a, b: a % b)
# rhs >= 0
res = lower_intrin(tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
# lhs >= 0
res = lower_intrin(tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
res = lower_intrin([x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero))
check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
# const power of two
res = lower_intrin(tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
......
@@ -24,18 +24,26 @@ def test_lower_warp_mem():
s = te.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], 32)
xi0, xi1 = s[B].split(xi, factor=16)
xo, xi = s[B].split(B.op.axis[0], 64)
xi0, xi1 = s[B].split(xi, factor=32)
tx = te.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, te.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], 16)
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
f = tvm.lower(s, [A, B])
fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
fdevice = tvm.tir.ir_pass.LowerWarpMemory(fdevice, 16)
# temp adapter to convert loweredFunc to IRModule
# to test passes in the new style.
fname = fdevice.name
mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)
......