Unverified Commit 4e5c5843 by Haozheng Fan Committed by GitHub

[TIR][PASS] dtype rewrite for indexing variables (#5092)

parent 4195b2e2
...@@ -115,6 +115,15 @@ class ConstIntBoundAnalyzer {
ConstIntBound operator()(const PrimExpr& expr);
/*!
* \brief Analyze the expr with intermediate results memoized to avoid redundant computation.
* \param expr The expression of interest.
* \param bound The lookup table to store the intermediate results.
* \return The result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
......
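For context, the bound information this new overload caches is exactly what the narrowing pass added later in this commit queries for every integer sub-expression. Below is a rough, non-authoritative sketch of the equivalent (non-memoized) query from Python, assuming the tvm.arith.Analyzer API; the constants are purely illustrative:

import tvm

analyzer = tvm.arith.Analyzer()
# An index expression built from int64 constants; the analyzer proves its range.
idx = tvm.tir.const(2**15, dtype="int64") * tvm.tir.const(2**15, dtype="int64")
bound = analyzer.const_int_bound(idx)

# The criterion the pass applies: the expression can be narrowed to
# `target_bits` iff its proven bound fits into that signed range.
target_bits = 32
fits = (bound.min_value >= -(2**(target_bits - 1)) and
        bound.max_value <= 2**(target_bits - 1) - 1)
print(bound.min_value, bound.max_value, fits)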
...@@ -358,6 +358,15 @@ Stmt DecorateDeviceScope(Stmt stmt);
Stmt HoistIfThenElse(Stmt stmt);
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The Stmt to be rewritten.
* \param target_bits The bit width of the target datatype.
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);
/*!
* \brief Make a user callable API LoweredFunc.
*
* The main task of this function is to create code to:
......
...@@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
*/
TVM_DLL Pass LowerWarpMemory();
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
*
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();
} // namespace transform
} // namespace tir
} // namespace tvm
......
...@@ -159,6 +159,7 @@ def lower(sch,
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.NarrowDataType(stmt, 32)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
......
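With this change, lower() narrows indexing dtypes to 32 bits immediately after StorageFlatten. A minimal sketch of calling the same ir_pass entry point directly on an ir_builder stmt, mirroring the unit tests added in this patch (the buffer shape and dtypes here are illustrative):

import tvm

ib = tvm.tir.ir_builder.create()
m = tvm.tir.const(64, dtype="int64")
Ab = tvm.tir.decl_buffer((m,), name="A")
A = ib.buffer_ptr(Ab)
with ib.for_range(0, m, name="i", dtype="int64") as i:
    A[i] = A[i] + 1
stmt = ib.get()

# Narrow int64 indices that provably fit into 32 bits.
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, 32)
assert stmt.loop_var.dtype == "int32"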
...@@ -370,7 +370,8 @@ class IterVar(Object, ExprOp):
raise TypeError("dom need to be Range")
name = var if var is not None else "iter"
dtype = "int32" if dom is None else dom.extent.dtype
var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
......
...@@ -76,7 +76,8 @@ class BufferVar(ObjectGeneric):
def __getitem__(self, index):
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
return _expr.Load(self._content_type, self._buffer_var, index)
def __setitem__(self, index, value):
...@@ -87,7 +88,8 @@ class BufferVar(ObjectGeneric):
value.dtype, self._content_type))
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
self._builder.emit(_stmt.Store(self._buffer_var, value, index))
......
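The Ramp stride constant now inherits the dtype of the (possibly narrowed) index base instead of hard-coding int32, keeping vectorized load/store indices dtype-consistent. A small sketch, assuming the same ir_builder usage as test_multilanes below:

import tvm

ib = tvm.tir.ir_builder.create()
n = tvm.tir.const(16, dtype="int64")
Ab = tvm.tir.decl_buffer((n,), dtype="float32x2", name="A")
A = ib.buffer_ptr(Ab)
with ib.for_range(0, n, name="i", dtype="int64") as i:
    # Indexing the 2-lane buffer emits Ramp(i * 2, const(1, "int64"), 2):
    # the stride constant matches the int64 base rather than defaulting to int32.
    A[i] = A[i] + 1
stmt = ib.get()
print(stmt.body)  # the Store whose index is the Ramp described above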
...@@ -66,3 +66,18 @@ def LowerWarpMemory():
The result pass
"""
return _ffi_api.LowerWarpMemory()
def NarrowDataType():
"""Narrow down PrimExpr datatype in stmt to target_bits.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
Note
----
Run this pass after StorageFlatten.
"""
return _ffi_api.NarrowDataType()
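In the pass-object form, target_bits is read from a PrimFunc attribute rather than passed as an argument. A usage sketch that mirrors the lower_stmt helper in the unit tests added by this patch (shapes are illustrative):

import tvm

ib = tvm.tir.ir_builder.create()
m = tvm.tir.const(128, dtype="int64")
Ab = tvm.tir.decl_buffer((m,), name="A")
A = ib.buffer_ptr(Ab)
with ib.for_range(0, m, name="i", dtype="int64") as i:
    A[i] = A[i] + 1

# Wrap the stmt in a PrimFunc, attach target_bits, then run the pass.
func = tvm.tir.PrimFunc([Ab], ib.get()).with_attr("target_bits", 32)
mod = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))
assert mod["main"].body.loop_var.dtype == "int32"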
...@@ -146,9 +146,30 @@ class ConstIntBoundAnalyzer::Impl :
res = Intersect(res, info.bound);
}
}
if (bound_) {
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value &&
val->second->max_value == res.max_value)
<< "Detected bound for " << expr
<< "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}
Entry VisitExpr_(const RampNode* op) final {
// op = {base + i * stride | 0 <= i < lanes}
// Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
// Note that `base + i * stride` is linear w.r.t. `i`
// Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
Entry a = VisitExpr(op->base);
Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
return Union(a, b);
}
Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype);
...@@ -340,10 +361,13 @@ class ConstIntBoundAnalyzer::Impl :
}
private:
friend class ConstIntBoundAnalyzer;
// internal variable map
std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// lookup table for memoization
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
// constants: the limit value means unlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
...@@ -536,6 +560,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
return ConstIntBound(ret.min_value, ret.max_value);
}
ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
return ConstIntBound(ret.min_value, ret.max_value);
}
void ConstIntBoundAnalyzer::Update(const Var& var,
const ConstIntBound& info,
bool override) {
......
...@@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
op->loop_var,
op->body);
}
......
...@@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
CHECK(op->for_type == ForType::Serial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
op->loop_var, op->body);
}
......
...@@ -452,7 +452,7 @@ Buffer BufferNode::make(Var data,
n->buffer_type = buffer_type;
if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
for (size_t i = 0; i < n->shape.size(); ++i) {
n->strides.push_back(Var("stride", n->shape[i].dtype()));
}
}
return Buffer(n);
......
...@@ -156,5 +156,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
...@@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->for_type, for_node->device_api, body);
}
}
......
...@@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator {
PrimExpr extent = tir::Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
int value = -1;
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
value = static_cast<int>(v1->value);
}
return value;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file narrow_datatype.cc
* \brief narrow the datatype of indexing vars
*/
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tir {
// This pass narrows indexing expressions (like StoreNode::index)
// that trivially fit into i32/i16 (denoted by `target_bits_`) down to
// i32/i16. Since i32/i16 indices may be more efficient on some backends
// (while i64 may be more efficient on others, such as llvm), this pass is
// useful on targets where narrower indices are cheaper.
//
// For a Var v, we determine its dtype by examining all the PrimExprs
// that contain v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
// If every expression in E fits into i32/i16, we conclude that v can be
// narrowed to i32/i16.
//
// To make an indexing expression i32/i16, we must make sure that every
// component of that expression is of dtype i32/i16. So we rewrite the
// following nodes when they appear inside an indexing expression:
// - Var
// - IntImm
// - Cast
//
// Algorithm:
// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
// - Use DataTypeRewriter to rewrite the components of an indexing expression.
using arith::Analyzer;
using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;
// Determine the result dtype for Var, IntImm and Cast,
// which will be stored in `vmap` eventually.
//
// Algorithm:
// We propagate the dtypes of all the Exprs that contain Var `var` into `vmap[var]`.
// To be more specific, if every Expr `e` that contains `var`
// (i.e. `var` is a child node of `e` in the AST) fits into `target_bits_`,
// then we narrow `var` to `target_bits_`. That is,
// `vmap[var] = min(target_bits_, var.dtype.bits())`.
// Otherwise, `var` is not narrowed, that is, `vmap[var] = var.dtype.bits()`.
class DataTypeVisitor final : public StmtExprVisitor {
public:
explicit DataTypeVisitor(int target_bits)
: bits_(target_bits), target_bits_(target_bits) {}
void VisitExpr(const PrimExpr& e) {
if (e.dtype().is_int()) {
int bits = max_bits_;
const PrimExprNode* op = e.as<PrimExprNode>();
if (bound_.find(op) == bound_.end()) {
analyzer_.const_int_bound(e, &bound_);
}
ConstIntBound bound = bound_[op];
int64_t ubound = Downcast<IntImm>(max_value(DataType::Int(target_bits_)))->value;
int64_t lbound = Downcast<IntImm>(min_value(DataType::Int(target_bits_)))->value;
if (e.dtype().bits() <= target_bits_ ||
(bound->max_value <= ubound && bound->min_value >= lbound)) {
bits = target_bits_;
}
int tmp = bits > bits_ ? bits : bits_;
std::swap(bits_, tmp);
StmtExprVisitor::VisitExpr(e);
std::swap(bits_, tmp);
} else {
StmtExprVisitor::VisitExpr(e);
}
}
void VisitStmt_(const ForNode* op) {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
vextent_[op->loop_var.as<VarNode>()] = op->extent.dtype();
return StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var,
Range::make_by_min_extent(0, op->value));
vextent_[iv->var.as<VarNode>()] = op->value.dtype();
StmtExprVisitor::VisitStmt_(op);
} else {
StmtExprVisitor::VisitStmt_(op);
}
}
void VisitExpr_(const ReduceNode* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
vextent_[iv->var.as<VarNode>()] = iv->dom->extent.dtype();
}
// Recursively call simplification when necessary.
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const VarNode* op) {
if (vextent_.find(op) != vextent_.end()) {
// We only narrow and never promote, so the result dtype
// is upperbounded by its original dtype before rewrite.
int bits = std::min(vextent_[op].bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
// We take maximum bits for all the possible Expr where a var occurs
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const IntImmNode* op) {
if (op->dtype.is_int()) {
// We only narrow and never promote, so the result dtype
// is upperbounded by its original dtype before rewrite.
int bits = std::min(op->dtype.bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
}
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const CastNode* op) {
if (op->dtype.is_int()) {
// We only narrow and never promote, so the result dtype
// is upperbounded by its original dtype before rewrite.
int bits = std::min(op->dtype.bits(), bits_);
if (vmap.find(op) == vmap.end()) {
vmap[op] = op->dtype.with_bits(bits);
} else {
vmap[op] = op->dtype.with_bits(std::max(vmap[op].bits(), bits));
}
}
StmtExprVisitor::VisitExpr_(op);
}
// the narrowed datatype of Var and IntImm
std::unordered_map<const PrimExprNode*, DataType> vmap;
protected:
// internal analyzer
arith::Analyzer analyzer_;
private:
// the maximum possible bits, which serves as an init value
static constexpr const int max_bits_ = 64;
// the maximum possible bit of the current expression's return dtype
int bits_;
// the target bits
int target_bits_;
// the extent of vars to be rewritten
std::unordered_map<const VarNode*, DataType> vextent_;
// the memoized bound generated by ConstIntBoundAnalyzer
std::unordered_map<const PrimExprNode*, ConstIntBound> bound_;
};
class DataTypeRewriter : public StmtExprMutator {
public:
explicit DataTypeRewriter(int target_bits): visitor_(target_bits) {}
Stmt operator()(Stmt s) {
visitor_(s);
for (auto i = visitor_.vmap.begin(), last = visitor_.vmap.end(); i != last;) {
PrimExpr e = GetRef<PrimExpr>(i->first);
if (e.dtype() == i->second) {
i = visitor_.vmap.erase(i);
} else {
++i;
}
}
return VisitStmt(s);
}
Stmt VisitStmt_(const StoreNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
is_index_ = true;
PrimExpr index = this->VisitExpr(op->index);
is_index_ = false;
Stmt s = StoreNode::make(op->buffer_var,
op->value,
index,
op->predicate);
return StmtExprMutator::VisitStmt_(s.as<StoreNode>());
}
Stmt VisitStmt_(const ForNode* op) final {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<ForNode>();
CHECK(op != nullptr)
<< "Expected type to be ForNode"
<< ", but get " << s->GetTypeKey();
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var>(e);
return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
op->for_type, op->device_api, op->body);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<AttrStmtNode>();
CHECK(op != nullptr)
<< "Expected type to be AttrStmtNode"
<< ", but get " << s->GetTypeKey();
const IterVarNode* iv = op->node.as<IterVarNode>();
CHECK(iv != nullptr)
<< "Expected type to be IterVarNode"
<< ", but get " << op->node->GetTypeKey();
PrimExpr e = VisitExpr(iv->var);
Var var = Downcast<Var>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
ivmap_[iv] = IterVarNode::make(iv->dom, var, iv->iter_type, iv->thread_tag);
}
return AttrStmtNode::make(
ivmap_[iv],
op->attr_key,
cast(var.dtype(), op->value),
op->body);
}
return StmtExprMutator::VisitStmt_(op);
}
PrimExpr VisitExpr_(const VarNode* op) final {
if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
if (vmap_.find(op) == vmap_.end()) {
vmap_[op] = Var(op->name_hint, visitor_.vmap[op]);
}
return vmap_[op];
}
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr VisitExpr_(const SizeVarNode* op) final {
if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
if (vmap_.find(op) == vmap_.end()) {
vmap_[op] = SizeVar(op->name_hint, visitor_.vmap[op]);
}
return vmap_[op];
}
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr VisitExpr_(const LoadNode* op) final {
is_index_ = true;
PrimExpr index = this->VisitExpr(op->index);
is_index_ = false;
PrimExpr e = LoadNode::make(op->dtype, op->buffer_var, index, op->predicate);
return StmtExprMutator::VisitExpr_(e.as<LoadNode>());
}
PrimExpr VisitExpr_(const IntImmNode* op) final {
if (is_index_) {
if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
return IntImm(visitor_.vmap[op], op->value);
}
}
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr VisitExpr_(const CastNode* op) final {
if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
PrimExpr e = StmtExprMutator::VisitExpr_(op);
const CastNode* new_op = e.as<CastNode>();
CHECK(new_op != nullptr)
<< "Expected type to be CastNode"
<< ", but get " << e->GetTypeKey();
return CastNode::make(visitor_.vmap[op], new_op->value);
}
return StmtExprMutator::VisitExpr_(op);
}
PrimExpr VisitExpr_(const AddNode* op) final;
PrimExpr VisitExpr_(const SubNode* op) final;
PrimExpr VisitExpr_(const MulNode* op) final;
PrimExpr VisitExpr_(const DivNode* op) final;
PrimExpr VisitExpr_(const ModNode* op) final;
PrimExpr VisitExpr_(const FloorDivNode* op) final;
PrimExpr VisitExpr_(const FloorModNode* op) final;
PrimExpr VisitExpr_(const MinNode* op) final;
PrimExpr VisitExpr_(const MaxNode* op) final;
PrimExpr VisitExpr_(const EQNode* op) final;
PrimExpr VisitExpr_(const NENode* op) final;
PrimExpr VisitExpr_(const LTNode* op) final;
PrimExpr VisitExpr_(const LENode* op) final;
PrimExpr VisitExpr_(const GTNode* op) final;
PrimExpr VisitExpr_(const GENode* op) final;
PrimExpr VisitExpr_(const CallNode* op) final;
private:
// the internal visitor to deduce the narrowed dtype
DataTypeVisitor visitor_;
// a map from Var before rewrite to that after rewrite,
// ensures one old Var maps to exactly one new Var
std::unordered_map<const VarNode*, Var> vmap_;
// a map from IterVar before rewrite to that after rewrite,
// ensures one old IterVar maps to exactly one new IterVar
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
// indicator of LoadNode::index and StoreNode::index
bool is_index_{false};
};
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) { \
PrimExpr a = this->VisitExpr(op->a); \
PrimExpr b = this->VisitExpr(op->b); \
if (a.same_as(op->a) && \
b.same_as(op->b)) { \
return GetRef<PrimExpr>(op); \
} else { \
return FUNC(a, b); \
} \
}
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator <)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator >)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=)
PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
PrimExpr e = StmtExprMutator::VisitExpr_(op);
op = e.as<CallNode>();
CHECK(op != nullptr)
<< "Expected type to be CallNode"
<< ", but get " << e->GetTypeKey();
if (op->call_type == CallNode::PureIntrinsic) {
if (op->name == intrinsic::tvm_if_then_else) {
return if_then_else(op->args[0], op->args[1], op->args[2]);
} else if (op->name == CallNode::shift_right) {
return op->args[0] >> op->args[1];
} else if (op->name == CallNode::shift_left) {
return op->args[0] << op->args[1];
} else if (op->name == CallNode::bitwise_and) {
return op->args[0] & op->args[1];
} else if (op->name == CallNode::bitwise_or) {
return op->args[0] | op->args[1];
} else if (op->name == CallNode::bitwise_xor) {
return op->args[0] ^ op->args[1];
} else if (op->name == "pow") {
return pow(op->args[0], op->args[1]);
}
}
return e;
}
Stmt NarrowDataType(Stmt stmt, int target_bits) {
return DataTypeRewriter(target_bits)(stmt);
}
namespace transform {
Pass NarrowDataType() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
IntImm target_bits = f->GetAttr<IntImm>("target_bits");
CHECK(target_bits.defined())
<< "NarrowDataType: Require the target_bits";
n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(
pass_func, 0, "tir.NarrowDataType", {});
}
TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
.set_body_typed(NarrowDataType);
} // namespace transform
} // namespace tir
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm.tir import const
def lower_stmt(params, stmt, target_bits):
func = tvm.tir.PrimFunc(params, stmt).with_attr(
"target_bits", target_bits)
func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
stmt = func.body
return stmt
def lower_sch(sch, args, target_bits):
binds = {}
arg_list = []
for x in args:
if isinstance(x, te.tensor.Tensor):
buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
return lower_stmt(arg_list, stmt, target_bits)
def test_basic():
def check(m, n, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m, n), name='B')
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name='i') as i:
with ib.for_range(0, n, name='j') as j:
B[i * n + j] = A[i * n + j] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.body.loop_var.dtype == target_dtype
# const shape
# i32 -> i32
check(2, 2, 32, "int32")
check(2**16, 2**16, 32, "int32") # i32 + i32 is not promoted to i64 even if overflow
# i64 -> i32
check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32")
check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64")
# i32 -> i16
check(2, 2, 16, "int16")
check(2**10, 2**10, 16, "int32")
# symbolic shape
check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32")
check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64")
def test_thread_axis():
def check(m, n, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m, n), name='B')
B = ib.buffer_ptr(Bb)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", m)
ib.scope_attr(tx, "thread_extent", n)
B[bx * n + tx] = A[bx * n + tx] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.node.var.dtype == target_dtype
assert stmt.body.node.var.dtype == target_dtype
# i32 -> i32
check(2, 32,
target_bits=32, target_dtype='int32')
check(2**30, 32, # i32 + i32 is not promoted to i64 even in the case of overflow
target_bits=32, target_dtype='int32')
# i64 -> i32
check(const(2, dtype='int64'),
const(32, dtype='int64'),
target_bits=32, target_dtype='int32')
check(const(2**30, dtype='int64'),
const(32, dtype='int64'),
target_bits=32, target_dtype='int64')
# i32 -> i16
check(2, 32,
target_bits=16, target_dtype='int16')
check(2**14, 32,
target_bits=16, target_dtype='int32')
def test_multilanes():
def check(m, lanes, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B')
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name='i', dtype=m.dtype) as i:
B[i] = A[i] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
# i32 -> i32
check(const(2 ** 10, dtype='int32'), 2,
target_bits=32, target_dtype='int32')
check(const(2 ** 32, dtype='int32'), 2,
target_bits=32, target_dtype='int32')
# i64 -> i32
check(const(2 ** 10, dtype='int64'), 2,
target_bits=32, target_dtype='int32')
check(const(2 ** 32, dtype='int64'), 2,
target_bits=32, target_dtype='int64')
# i32 -> i16
check(const(2 ** 10, dtype='int32'), 2,
target_bits=16, target_dtype='int16')
check(const(2 ** 16, dtype='int32'), 2,
target_bits=16, target_dtype='int32')
def test_reduce():
def check(m, target_bits, target_dtype):
A = te.placeholder((m,), name='A', dtype='float32')
k = te.reduce_axis((0, m), "k")
B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
s = te.create_schedule(B.op)
stmt = lower_sch(s, [A, B], target_bits)
assert stmt.body[1].loop_var.dtype == target_dtype
# i32 -> i32
check(const(64, dtype='int32'), 32, 'int32')
# i64 -> i32
check(const(64, dtype='int64'), 32, 'int32')
# i32 -> i16
check(const(64, dtype='int32'), 16, 'int16')
check(const(2**16, dtype='int32'), 16, 'int32')
# symbolic
check(te.var('n', dtype='int32'), 32, 'int32')
check(te.var('n', dtype='int64'), 32, 'int64')
def test_slice():
def check(m, n, target_bits, target_dtype):
# The index may overflow in B, while not in A
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m, n * 2), name='B')
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name='i') as i:
with ib.for_range(0, n, name='j') as j:
A[i * n + j] = B[i * 2 * n + 2 * j] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.body.loop_var.dtype == target_dtype
# The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
check(const(2**15, 'int64'), const(2**15, 'int64'),
target_bits=32, target_dtype='int32')
# The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
check(const(2**15, 'int64'), const((2**15 + 1), 'int64'),
target_bits=32, target_dtype='int64')
if __name__ == "__main__":
test_basic()
test_thread_axis()
test_multilanes()
test_reduce()
test_slice()