Unverified Commit 4cebb1c7 by Tianqi Chen Committed by GitHub

[ARITH] Remove legacy const pattern functions (#5387)

parent d9cecdf5
...@@ -56,27 +56,6 @@ template<typename Op> ...@@ -56,27 +56,6 @@ template<typename Op>
inline PrimExpr ComputeReduce( inline PrimExpr ComputeReduce(
const Array<PrimExpr>& values, PrimExpr empty_value); const Array<PrimExpr>& values, PrimExpr empty_value);
inline bool GetConst(PrimExpr e, int64_t* out) {
if (e.dtype().is_vector()) return false;
const int64_t* v = tir::as_const_int(e);
if (v) {
*out = *v; return true;
} else {
return false;
}
}
// get a small constant int
inline bool GetConstInt(PrimExpr e, int* out) {
int64_t v1 = 0;
if (GetConst(e, &v1)) {
if (v1 > static_cast<int64_t>(
std::numeric_limits<int>::max())) return false;
*out = static_cast<int>(v1); return true;
}
return false;
}
template<> template<>
inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) { inline PrimExpr Compute<tir::AddNode>(PrimExpr a, PrimExpr b) {
return a + b; return a + b;
......
...@@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base, ...@@ -574,6 +574,17 @@ ramp(const Pattern<TBase>& base,
base.derived(), stride.derived(), lanes.derived()); base.derived(), stride.derived(), lanes.derived());
} }
template<typename TBase>
inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>
ramp(const Pattern<TBase>& base,
int stride,
int lanes) {
return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
base.derived(),
PConstWithTypeLike<TBase>(base.derived(), stride),
PConst<int>(lanes));
}
/*! /*!
* \brief Pattern broadcast expression. * \brief Pattern broadcast expression.
* \tparam TA The pattern type of the value. * \tparam TA The pattern type of the value.
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "codegen_llvm.h" #include "codegen_llvm.h"
#include "codegen_cpu.h" #include "codegen_cpu.h"
#include "../../arith/pattern_match.h"
#include "../build_common.h" #include "../build_common.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -363,16 +364,16 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, ...@@ -363,16 +364,16 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
md_builder_->createTBAAStructTagNode(meta, meta, 0)); md_builder_->createTBAAStructTagNode(meta, meta, 0));
return; return;
} }
int base = 0, width = 0;
int64_t base = 0, width = 0;
arith::PVar<IntImm> pbase, pstride;
arith::PVar<int> planes;
// create meta-data for alias analysis // create meta-data for alias analysis
// Use a group of binary tree ranges of memory banks. // Use a group of binary tree ranges of memory banks.
if (index.defined()) { if (index.defined()) {
const RampNode* ramp = index.as<RampNode>(); if (arith::ramp(pbase, pstride, planes).Match(index)) {
if (ramp) { base = pbase.Eval()->value;
int base, stride; int64_t xwith = planes.Eval() * pstride.Eval()->value;
if (arith::GetConstInt(ramp->base, &base) &&
arith::GetConstInt(ramp->stride, &stride)) {
int xwith = ramp->lanes * stride;
width = 1; width = 1;
while (width < xwith) { while (width < xwith) {
width *= 2; width *= 2;
...@@ -381,9 +382,9 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, ...@@ -381,9 +382,9 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
base -= base % width; base -= base % width;
width *= 2; width *= 2;
} }
} } else if (auto* ptr = index.as<tir::IntImmNode>()) {
} else { width = 1;
if (arith::GetConstInt(index, &base)) width = 1; base = ptr->value;
} }
} }
llvm::MDNode* meta = md_tbaa_root_; llvm::MDNode* meta = md_tbaa_root_;
...@@ -394,8 +395,8 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, ...@@ -394,8 +395,8 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta); meta = md_builder_->createTBAAScalarTypeNode(buffer_type.str(), meta);
// create a tree-shape access structure. // create a tree-shape access structure.
if (width != 0) { if (width != 0) {
for (int w = 1024; w >= width; w /= 2) { for (int64_t w = 1024; w >= width; w /= 2) {
int b = (base / w) * w; int64_t b = (base / w) * w;
std::stringstream os; std::stringstream os;
os << buffer << ".w" << w << ".b" << b; os << buffer << ".w" << w << ".b" << b;
meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta); meta = md_builder_->createTBAAScalarTypeNode(os.str(), meta);
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
#include <iomanip> #include <iomanip>
#include <cctype> #include <cctype>
#include "codegen_c.h" #include "codegen_c.h"
#include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
#include "../../tir/pass/ir_util.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef( ...@@ -198,8 +198,8 @@ std::string CodeGenC::GetBufferRef(
// optimize for case where it is in register, // optimize for case where it is in register,
if (HandleTypeMatch(buffer, t) && !is_vol) { if (HandleTypeMatch(buffer, t) && !is_vol) {
// optimize for constant access // optimize for constant access
int offset; if (auto* ptr = index.as<tir::IntImmNode>()) {
if (arith::GetConstInt(index, &offset)) { int64_t offset = ptr->value;
CHECK_EQ(offset % t.lanes(), 0) CHECK_EQ(offset % t.lanes(), 0)
<< "Find unaligned vector load to a vector type"; << "Find unaligned vector load to a vector type";
os << vid << '[' << (offset / t.lanes()) << ']'; os << vid << '[' << (offset / t.lanes()) << ']';
...@@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*) ...@@ -663,9 +663,10 @@ void CodeGenC::VisitExpr_(const LoadNode* op, std::ostream& os) { // NOLINT(*)
} else { } else {
CHECK(is_one(op->predicate)) CHECK(is_one(op->predicate))
<< "predicated load is not supported"; << "predicated load is not supported";
PrimExpr base;
if (GetRamp1Base(op->index, op->dtype.lanes(), &base)) { arith::PVar<PrimExpr> base;
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base); if (arith::ramp(base, 1, op->dtype.lanes()).Match(op->index)) {
std::string ref = GetVecLoad(op->dtype, op->buffer_var.get(), base.Eval());
HandleVolatileLoads(ref, op, os); HandleVolatileLoads(ref, op, os);
} else { } else {
std::ostringstream svalue_expr; std::ostringstream svalue_expr;
...@@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) { ...@@ -708,10 +709,10 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
} else { } else {
CHECK(is_one(op->predicate)) CHECK(is_one(op->predicate))
<< "Predicated store is not supported"; << "Predicated store is not supported";
PrimExpr base; arith::PVar<PrimExpr> base;
if (GetRamp1Base(op->index, t.lanes(), &base)) { if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value); std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value); this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else { } else {
// The assignment below introduces side-effect, and the resulting value cannot // The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed // be reused across multiple expression, thus a new scope is needed
......
...@@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex( ...@@ -103,11 +103,11 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(
spirv::Value v; spirv::Value v;
if (ts.rank == 1) { if (ts.rank == 1) {
v = builder_->GetLocalID(ts.dim_index); v = builder_->GetLocalID(ts.dim_index);
int size = 0; auto* sizeptr = extent.as<tir::IntImmNode>();
CHECK(arith::GetConstInt(extent, &size)) CHECK(sizeptr)
<< "SPIRV only allows constant thread group size " << " get " << extent; << "SPIRV only allows constant thread group size " << " get " << extent;
CHECK_LT(ts.dim_index, 3); CHECK_LT(ts.dim_index, 3);
workgroup_size_[ts.dim_index] = static_cast<uint32_t>(size); workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
} else { } else {
v = builder_->GetWorkgroupID(ts.dim_index); v = builder_->GetWorkgroupID(ts.dim_index);
} }
......
...@@ -291,9 +291,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, ...@@ -291,9 +291,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer,
} }
// Byte_offset field. // Byte_offset field.
int data_bytes = GetVectorBytes(buffer->dtype); int data_bytes = GetVectorBytes(buffer->dtype);
int64_t const_offset;
if (arith::GetConst(buffer->elem_offset, &const_offset)) { if (const auto* const_offset = buffer->elem_offset.as<IntImmNode>()) {
Bind_(make_const(DataType::UInt(64), const_offset * data_bytes), Bind_(make_const(DataType::UInt(64), const_offset->value * data_bytes),
TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset), TVMArrayGet(DataType::UInt(64), handle, intrinsic::kArrByteOffset),
arg_name + ".byte_offset", true); arg_name + ".byte_offset", true);
} else { } else {
......
...@@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) { ...@@ -174,22 +174,6 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
return align; return align;
} }
/*!
* \brief Pattern match index to Ramp with stride=1
* This is a common pattern in continuous memory load.
* \param index The index formula
* \param lanes number of lanes in the ramp
* \param base The result base.
* \return true if pattern match success and store the base to base.
*/
inline bool GetRamp1Base(PrimExpr index, int lanes, PrimExpr *base) {
const RampNode* r = index.as<RampNode>();
if (!r) return false;
if (!is_one(r->stride)) return false;
CHECK_EQ(r->lanes, lanes);
*base = r->base;
return true;
}
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
#endif // TVM_TIR_PASS_IR_UTIL_H_ #endif // TVM_TIR_PASS_IR_UTIL_H_
...@@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor { ...@@ -57,15 +57,15 @@ class ExprTouched final : public StmtExprVisitor {
} }
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
int rw_mask = 0; const auto* rw_mask = op->args[4].as<IntImmNode>();
CHECK(arith::GetConstInt(op->args[4], &rw_mask));
const VarNode* buffer_var = op->args[1].as<VarNode>(); const VarNode* buffer_var = op->args[1].as<VarNode>();
CHECK(buffer_var); CHECK(buffer_var);
CHECK(rw_mask);
// read // read
if (rw_mask & 1) { if (rw_mask->value & 1) {
HandleUseVar(buffer_var); HandleUseVar(buffer_var);
} }
if (rw_mask & 2) { if (rw_mask->value & 2) {
HandleWriteVar(buffer_var); HandleWriteVar(buffer_var);
} }
this->VisitExpr(op->args[2]); this->VisitExpr(op->args[2]);
......
...@@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { ...@@ -163,8 +163,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
CHECK_GE(e.scope.dim_index, 0) CHECK_GE(e.scope.dim_index, 0)
<< "vthread do not work with cross thread reduction"; << "vthread do not work with cross thread reduction";
if (e.scope.rank == 1) { if (e.scope.rank == 1) {
CHECK(arith::GetConstInt(attr->value, &(e.extent))) const auto* ptr = attr->value.as<IntImmNode>();
CHECK(ptr)
<< "Need constant extent for reduce set " << iv; << "Need constant extent for reduce set " << iv;
e.extent = static_cast<int>(ptr->value);
if (reduce_set.count(iv->var.get())) { if (reduce_set.count(iv->var.get())) {
vred.push_back(e); vred.push_back(e);
++nmatch; ++nmatch;
......
...@@ -30,7 +30,6 @@ ...@@ -30,7 +30,6 @@
#include <unordered_set> #include <unordered_set>
#include "../pass/ir_util.h" #include "../pass/ir_util.h"
#include "../../arith/compute_expr.h"
namespace tvm { namespace tvm {
namespace tir { namespace tir {
...@@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator { ...@@ -94,11 +93,10 @@ class BuiltinLower : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op); Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>(); op = stmt.as<AllocateNode>();
// Get constant allocation bound. // Get constant allocation bound.
int64_t dev_type;
int64_t nbytes = GetVectorBytes(op->dtype); int64_t nbytes = GetVectorBytes(op->dtype);
if (device_type_.defined()) { if (device_type_.defined()) {
if (arith::GetConst(device_type_, &dev_type)) { if (const auto* dev_type = device_type_.as<IntImmNode>()) {
if (dev_type == kDLCPU) { if (dev_type->value == kDLCPU) {
int32_t constant_size = op->constant_allocation_size(); int32_t constant_size = op->constant_allocation_size();
if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) { if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
return stmt; return stmt;
......
...@@ -37,7 +37,7 @@ ...@@ -37,7 +37,7 @@
#include <unordered_set> #include <unordered_set>
#include "../pass/ir_util.h" #include "../../arith/pattern_match.h"
#include "../../arith/compute_expr.h" #include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h" #include "../../runtime/thread_storage_scope.h"
...@@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor { ...@@ -121,11 +121,11 @@ class WarpStoreCoeffFinder : private StmtVisitor {
if (op->value.dtype().lanes() == 1) { if (op->value.dtype().lanes() == 1) {
UpdatePattern(op->index); UpdatePattern(op->index);
} else { } else {
PrimExpr base; arith::PVar<PrimExpr> base;
CHECK(GetRamp1Base(op->index, op->value.dtype().lanes(), &base)) CHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(op->index))
<< "LowerWarpMemory failed due to store index=" << op->index << "LowerWarpMemory failed due to store index=" << op->index
<< ", can only handle continuous store"; << ", can only handle continuous store";
UpdatePattern(base); UpdatePattern(base.Eval());
} }
} else { } else {
StmtVisitor::VisitStmt_(op); StmtVisitor::VisitStmt_(op);
...@@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor { ...@@ -137,19 +137,18 @@ class WarpStoreCoeffFinder : private StmtVisitor {
arith::DetectLinearEquation(index, {warp_index_}); arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U) CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index; << "LowerWarpMemory failed due to store index=" << index;
int coeff = 0;
PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) CHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
<< "LowerWarpMemory failed due to store index=" << index << "LowerWarpMemory failed due to store index=" << index
<< ", require positive constant coefficient on warp index " << warp_index_ << ", require positive constant coefficient on warp index " << warp_index_
<< " but get " << mcoeff; << " but get " << mcoeff;
if (warp_coeff_ != 0) { if (warp_coeff_ != 0) {
CHECK_EQ(warp_coeff_, coeff) CHECK_EQ(warp_coeff_, mcoeff_as_int->value)
<< "LowerWarpMemory failed due to two different store coefficient to warp index"; << "LowerWarpMemory failed due to two different store coefficient to warp index";
} else { } else {
warp_coeff_ = coeff; warp_coeff_ = mcoeff_as_int->value;
} }
} }
...@@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor { ...@@ -158,7 +157,7 @@ class WarpStoreCoeffFinder : private StmtVisitor {
// the warp index // the warp index
Var warp_index_; Var warp_index_;
// the coefficient // the coefficient
int warp_coeff_{0}; int64_t warp_coeff_{0};
// analyzer. // analyzer.
arith::Analyzer* analyzer_; arith::Analyzer* analyzer_;
}; };
...@@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor { ...@@ -184,10 +183,10 @@ class WarpIndexFinder : private StmtVisitor {
if (op->attr_key == attr::thread_extent) { if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") { if (iv->thread_tag == "threadIdx.x") {
int value = 0; auto* value_as_int = op->value.as<IntImmNode>();
CHECK(arith::GetConstInt(op->value, &value) && CHECK(value_as_int &&
value <= warp_size_ && value_as_int->value <= warp_size_ &&
warp_size_ % value == 0) warp_size_ % value_as_int->value == 0)
<< "Expect threadIdx.x 's size to be no larger than, and a factor of" << "Expect threadIdx.x 's size to be no larger than, and a factor of"
<< " warp size(" << warp_size_ << ")" << " to enable warp memory" << " warp size(" << warp_size_ << ")" << " to enable warp memory"
<< " but get " << op->value << " instead"; << " but get " << op->value << " instead";
...@@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor { ...@@ -198,7 +197,7 @@ class WarpIndexFinder : private StmtVisitor {
<< "Please create it using thread_axis once and reuse the axis " << "Please create it using thread_axis once and reuse the axis "
<< "across multiple binds in the same kernel"; << "across multiple binds in the same kernel";
} else { } else {
width_ = value; width_ = value_as_int->value;
warp_index_ = iv; warp_index_ = iv;
} }
} }
...@@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator { ...@@ -281,9 +280,12 @@ class WarpAccessRewriter : protected StmtExprMutator {
// in this access pattern. // in this access pattern.
std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) { std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
if (index.dtype().lanes() != 1) { if (index.dtype().lanes() != 1) {
PrimExpr base, local_index, group; PrimExpr local_index, group;
CHECK(GetRamp1Base(index, index.dtype().lanes(), &base));
std::tie(local_index, group) = SplitIndexByGroup(base); arith::PVar<PrimExpr> base;
CHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));
std::tie(local_index, group) = SplitIndexByGroup(base.Eval());
local_index = local_index =
RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); RampNode::make(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
return std::make_pair(local_index, group); return std::make_pair(local_index, group);
......
...@@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator { ...@@ -326,13 +326,14 @@ class StorageFlattener : public StmtExprMutator {
<< "Prefetch dim should be the same as buffer dim"; << "Prefetch dim should be the same as buffer dim";
int block_size = 1, int block_size = 1,
elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(), elem_cnt = cache_line_size_ / e.buffer->dtype.bytes();
shape = 0;
int starts = op->bounds.size() - 1; int starts = op->bounds.size() - 1;
while (starts > 0 && arith::GetConstInt(e.buffer->shape[starts], &shape)
&& elem_cnt >= block_size * shape) { while (starts > 0) {
block_size *= shape; auto* shape_as_int = e.buffer->shape[starts].as<IntImmNode>();
if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break;
block_size *= static_cast<int>(shape_as_int->value);
starts--; starts--;
} }
PrimExpr stride(elem_cnt / block_size); PrimExpr stride(elem_cnt / block_size);
......
...@@ -51,16 +51,13 @@ class LoopUnroller : public StmtExprMutator { ...@@ -51,16 +51,13 @@ class LoopUnroller : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final { Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == "pragma_auto_unroll_max_step") { if (op->attr_key == "pragma_auto_unroll_max_step") {
int value = 0; int value = static_cast<int>(Downcast<Integer>(op->value)->value);
CHECK(arith::GetConstInt(op->value, &value));
std::swap(value, auto_max_step_); std::swap(value, auto_max_step_);
Stmt ret = this->VisitStmt(op->body); Stmt ret = this->VisitStmt(op->body);
std::swap(value, auto_max_step_); std::swap(value, auto_max_step_);
return ret; return ret;
} else if (op->attr_key == "pragma_unroll_explicit") { } else if (op->attr_key == "pragma_unroll_explicit") {
int value = 0; bool explicit_unroll = Downcast<Integer>(op->value)->value;
CHECK(arith::GetConstInt(op->value, &value));
bool explicit_unroll = value;
std::swap(explicit_unroll, explicit_unroll_); std::swap(explicit_unroll, explicit_unroll_);
Stmt ret = this->VisitStmt(op->body); Stmt ret = this->VisitStmt(op->body);
std::swap(explicit_unroll, explicit_unroll_); std::swap(explicit_unroll, explicit_unroll_);
......
...@@ -519,12 +519,11 @@ class LoopVectorizer : public StmtMutator { ...@@ -519,12 +519,11 @@ class LoopVectorizer : public StmtMutator {
Stmt VisitStmt_(const ForNode* op) final { Stmt VisitStmt_(const ForNode* op) final {
if (op->for_type == ForType::Vectorized) { if (op->for_type == ForType::Vectorized) {
CHECK(is_zero(op->min)); CHECK(is_zero(op->min));
int lanes = 0; auto* extent_as_int = op->extent.as<IntImmNode>();
bool succ = arith::GetConstInt(op->extent, &lanes); if (!extent_as_int || extent_as_int->value < 1) {
if (!succ || lanes < 1) {
LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent; LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
} }
return Vectorizer(op->loop_var, lanes)(op->body); return Vectorizer(op->loop_var, static_cast<int>(extent_as_int->value))(op->body);
} else { } else {
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment