Unverified Commit 4e5c5843 by Haozheng Fan Committed by GitHub

[TIR][PASS] dtype rewrite for indexing variables (#5092)

parent 4195b2e2
...@@ -115,6 +115,15 @@ class ConstIntBoundAnalyzer {
ConstIntBound operator()(const PrimExpr& expr);
/*!
* \brief Analyze the expr with intermediate results memoized to avoid redundant computation.
* \param expr The expression of interest.
* \param bound The lookup table to store the intermediate results.
* \return the result of the analysis.
*/
ConstIntBound operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound);
/*!
* \brief Update constant int bound information of var.
*
* \param var The variable of interest.
...
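The memoized overload above is only exposed on the C++ side; for intuition, here is a minimal sketch of the same constant-bound analysis through the existing public tvm.arith.Analyzer Python API (the concrete values are illustrative):

import tvm

analyzer = tvm.arith.Analyzer()
x = tvm.tir.Var("x", "int32")
# register a bound for x, then query a derived expression
analyzer.update(x, tvm.arith.ConstIntBound(0, 127))
bound = analyzer.const_int_bound(x + 1)
assert bound.min_value == 1 and bound.max_value == 128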
...@@ -358,6 +358,15 @@ Stmt DecorateDeviceScope(Stmt stmt);
Stmt HoistIfThenElse(Stmt stmt);
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The statement to rewrite.
* \param target_bits The target bit width of the index datatype.
* \return Transformed stmt.
*/
Stmt NarrowDataType(Stmt stmt, int target_bits);
/*!
* \brief Make an user callable API LoweredFunc.
*
* The main task of this function is to create code to :
...
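A hedged sketch of calling this statement-level pass from Python, as the lowering driver below does right after StorageFlatten (here stmt stands for an already flattened statement):

from tvm.tir import ir_pass

# narrow indexing variables to int32 where the analysis can prove it is safe
stmt = ir_pass.NarrowDataType(stmt, 32)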
...@@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
*/
TVM_DLL Pass LowerWarpMemory();
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
*
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();
} // namespace transform
} // namespace tir
} // namespace tvm
...
...@@ -159,6 +159,7 @@ def lower(sch,
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.NarrowDataType(stmt, 32)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
...
...@@ -370,7 +370,8 @@ class IterVar(Object, ExprOp):
raise TypeError("dom need to be Range")
name = var if var is not None else "iter"
dtype = "int32" if dom is None else dom.extent.dtype
var = Var(name, dtype=dtype) if not isinstance(var, Var) else var
self.__init_handle_by_constructor__(
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
...
...@@ -76,7 +76,8 @@ class BufferVar(ObjectGeneric):
def __getitem__(self, index):
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
return _expr.Load(self._content_type, self._buffer_var, index)
def __setitem__(self, index, value): def __setitem__(self, index, value):
...@@ -87,7 +88,8 @@ class BufferVar(ObjectGeneric):
value.dtype, self._content_type))
t = DataType(self._content_type)
if t.lanes > 1:
base = index * t.lanes
index = _expr.Ramp(base, const(1, base.dtype), t.lanes)
self._builder.emit(_stmt.Store(self._buffer_var, value, index))
...
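A small illustrative check of the effect (it follows the ir_builder path this change touches): with an int64 index into a vectorized buffer, the implicit Ramp now carries the index dtype instead of a hard-coded int32 constant:

import tvm
from tvm.tir import const

ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((16,), dtype="float32x4", name="A")
A = ib.buffer_ptr(Ab)
i = const(3, "int64")
load = A[i]  # Load whose index is Ramp(i * 4, const(1, "int64"), 4)
assert load.index.base.dtype == "int64"
assert load.index.stride.dtype == "int64"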
...@@ -66,3 +66,18 @@ def LowerWarpMemory():
The result pass
"""
return _ffi_api.LowerWarpMemory()
def NarrowDataType():
"""Narrow down PrimExpr datatype in stmt to target_bits.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
Note
----
Run this pass after StorageFlatten.
"""
return _ffi_api.NarrowDataType()
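As a usage sketch (mirroring the lower_stmt helper in the new test file below), the pass reads the target bit width from the "target_bits" attribute of the PrimFunc; here params and stmt stand for an existing parameter list and flattened statement:

import tvm

func = tvm.tir.PrimFunc(params, stmt).with_attr("target_bits", 32)
mod = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))
narrowed = mod["main"].body  # indexing variables rewritten to int32 where provably safe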
...@@ -146,9 +146,30 @@ class ConstIntBoundAnalyzer::Impl :
res = Intersect(res, info.bound);
}
}
if (bound_) {
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value &&
val->second->max_value == res.max_value)
<< "Detected bound for " << expr
<< "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
return res;
}
Entry VisitExpr_(const RampNode* op) final {
// op = {base + i * stride | 0 <= i < lanes}
// Entry(op) = Union(Entry(base + i * stride) | 0 <= i < lanes)
// Note that `base + i * stride` is linear w.r.t. `i`
// Entry(op) = Union(Entry(base + i * stride) | i = 0, i = lanes-1)
Entry a = VisitExpr(op->base);
Entry b = VisitExpr(op->base + (op->lanes - 1) * op->stride);
return Union(a, b);
}
Entry VisitExpr_(const CastNode* op) final {
Entry a = VisitExpr(op->value);
Entry b = Everything(op->dtype);
...@@ -340,10 +361,13 @@ class ConstIntBoundAnalyzer::Impl :
}
private:
friend class ConstIntBoundAnalyzer;
// internal variable map
std::unordered_map<Var, Entry, ObjectHash, ObjectEqual> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// lookup table for memoization
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound_{nullptr};
// constants: the limit value means unlimited
// NOTE: kNegInf/kPosInf are used to represent infinity.
static const constexpr int64_t kNegInf = ConstIntBound::kNegInf;
...@@ -536,6 +560,14 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) {
return ConstIntBound(ret.min_value, ret.max_value);
}
ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr,
std::unordered_map<const PrimExprNode*, ConstIntBound>* bound) {
impl_->bound_ = bound;
Entry ret = impl_->VisitExpr(expr);
impl_->bound_ = nullptr;
return ConstIntBound(ret.min_value, ret.max_value);
}
void ConstIntBoundAnalyzer::Update(const Var& var,
const ConstIntBound& info,
bool override) {
...
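To illustrate the new RampNode rule above: since base + i * stride is linear in i, the bound is the union of the bounds at the endpoints i = 0 and i = lanes - 1. For example, Ramp(8, 2, 4) covers {8, 10, 12, 14}, so its bound is [8, 14]. A hedged sketch of checking this through the Python-side analyzer:

import tvm
from tvm.tir import const

analyzer = tvm.arith.Analyzer()
ramp = tvm.tir.Ramp(const(8, "int32"), const(2, "int32"), 4)
bound = analyzer.const_int_bound(ramp)
assert bound.min_value == 8 and bound.max_value == 14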
...@@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
op->loop_var,
op->body);
}
...
...@@ -1121,7 +1121,8 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
CHECK(op->for_type == ForType::Serial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
op->loop_var, op->body);
}
...
...@@ -452,7 +452,7 @@ Buffer BufferNode::make(Var data,
n->buffer_type = buffer_type;
if (n->buffer_type == kAutoBroadcast && n->shape.size() > 0 && n->strides.empty()) {
for (size_t i = 0; i < n->shape.size(); ++i) {
n->strides.push_back(Var("stride", n->shape[i].dtype()));
}
}
return Buffer(n);
...
...@@ -156,5 +156,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
...@@ -587,7 +587,7 @@ inline Stmt LoopPartitioner::MakeFor(const Object *node, PrimExpr extent, Stmt b
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
return ForNode::make(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->for_type, for_node->device_api, body);
}
}
...
...@@ -160,7 +160,9 @@ class LoopUnroller : public StmtExprMutator {
PrimExpr extent = tir::Simplify(op->extent);
const IntImmNode *v1 = extent.as<IntImmNode>();
int value = -1;
// integers that do not fit in int32_t are treated as symbolic,
// as it's impossible to unroll such large loops
if (v1 != nullptr && v1->value <= std::numeric_limits<int>::max()) {
value = static_cast<int>(v1->value);
}
return value;
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
from tvm.tir import const
def lower_stmt(params, stmt, target_bits):
func = tvm.tir.PrimFunc(params, stmt).with_attr(
"target_bits", target_bits)
func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
stmt = func.body
return stmt
def lower_sch(sch, args, target_bits):
binds = {}
arg_list = []
for x in args:
if isinstance(x, te.tensor.Tensor):
buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
else:
raise ValueError("args must be Tensor, Buffer or Var")
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
return lower_stmt(arg_list, stmt, target_bits)
def test_basic():
def check(m, n, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m, n), name='B')
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name='i') as i:
with ib.for_range(0, n, name='j') as j:
B[i * n + j] = A[i * n + j] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.body.loop_var.dtype == target_dtype
# const shape
# i32 -> i32
check(2, 2, 32, "int32")
check(2**16, 2**16, 32, "int32") # i32 + i32 is not promoted to i64 even if it overflows
# i64 -> i32
check(const(2, dtype='int64'), const(2, dtype='int64'), 32, "int32")
check(const(2**16, dtype='int64'), const(2**16, dtype='int64'), 32, "int64")
# i32 -> i16
check(2, 2, 16, "int16")
check(2**10, 2**10, 16, "int32")
# symbolic shape
check(te.size_var(name='m', dtype='int32'), te.size_var(name='n', dtype='int32'), 32, "int32")
check(te.size_var(name='m', dtype='int64'), te.size_var(name='n', dtype='int64'), 32, "int64")
def test_thread_axis():
def check(m, n, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m, n), name='B')
B = ib.buffer_ptr(Bb)
bx = te.thread_axis("blockIdx.x")
tx = te.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", m)
ib.scope_attr(tx, "thread_extent", n)
B[bx * n + tx] = A[bx * n + tx] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.node.var.dtype == target_dtype
assert stmt.body.node.var.dtype == target_dtype
# i32 -> i32
check(2, 32,
target_bits=32, target_dtype='int32')
check(2**30, 32, # i32 + i32 is not promoted to i64 even in the case of overflow
target_bits=32, target_dtype='int32')
# i64 -> i32
check(const(2, dtype='int64'),
const(32, dtype='int64'),
target_bits=32, target_dtype='int32')
check(const(2**30, dtype='int64'),
const(32, dtype='int64'),
target_bits=32, target_dtype='int64')
# i32 -> i16
check(2, 32,
target_bits=16, target_dtype='int16')
check(2**14, 32,
target_bits=16, target_dtype='int32')
def test_multilanes():
def check(m, lanes, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m,), dtype='float32x{}'.format(lanes), name='B')
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name='i', dtype=m.dtype) as i:
B[i] = A[i] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
# i32 -> i32
check(const(2 ** 10, dtype='int32'), 2,
target_bits=32, target_dtype='int32')
check(const(2 ** 32, dtype='int32'), 2,
target_bits=32, target_dtype='int32')
# i64 -> i32
check(const(2 ** 10, dtype='int64'), 2,
target_bits=32, target_dtype='int32')
check(const(2 ** 32, dtype='int64'), 2,
target_bits=32, target_dtype='int64')
# i32 -> i16
check(const(2 ** 10, dtype='int32'), 2,
target_bits=16, target_dtype='int16')
check(const(2 ** 16, dtype='int32'), 2,
target_bits=16, target_dtype='int32')
def test_reduce():
def check(m, target_bits, target_dtype):
A = te.placeholder((m,), name='A', dtype='float32')
k = te.reduce_axis((0, m), "k")
B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
s = te.create_schedule(B.op)
stmt = lower_sch(s, [A, B], target_bits)
assert stmt.body[1].loop_var.dtype == target_dtype
# i32 -> i32
check(const(64, dtype='int32'), 32, 'int32')
# i64 -> i32
check(const(64, dtype='int64'), 32, 'int32')
# i32 -> i16
check(const(64, dtype='int32'), 16, 'int16')
check(const(2**16, dtype='int32'), 16, 'int32')
# symbolic
check(te.var('n', dtype='int32'), 32, 'int32')
check(te.var('n', dtype='int64'), 32, 'int64')
def test_slice():
def check(m, n, target_bits, target_dtype):
# The index may overflow in B, while not in A
ib = tvm.tir.ir_builder.create()
Ab = tvm.tir.decl_buffer((m, n), name='A')
A = ib.buffer_ptr(Ab)
Bb = tvm.tir.decl_buffer((m, n * 2), name='B')
B = ib.buffer_ptr(Bb)
with ib.for_range(0, m, name='i') as i:
with ib.for_range(0, n, name='j') as j:
A[i * n + j] = B[i * 2 * n + 2 * j] + 1
stmt = ib.get()
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.body.loop_var.dtype == target_dtype
# The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1
check(const(2**15, 'int64'), const(2**15, 'int64'),
target_bits=32, target_dtype='int32')
# The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1
check(const(2**15, 'int64'), const((2**15 + 1), 'int64'),
target_bits=32, target_dtype='int64')
if __name__ == "__main__":
test_basic()
test_thread_axis()
test_multilanes()
test_reduce()
test_slice()