Unverified Commit 37e57548 by yongfeng-nv Committed by GitHub

Improve IntervalSet's floormod (#5367)

parent 4a3fece7
...@@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer { ...@@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer {
* *
* \param var The variable. * \param var The variable.
* \param range The range we bind to. * \param range The range we bind to.
* \param override Whether we allow overriding an existing var's range.
*/ */
TVM_DLL void Bind(const Var& var, const Range& range); TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
private: private:
friend class Analyzer; friend class Analyzer;
...@@ -411,8 +412,9 @@ class TVM_DLL Analyzer { ...@@ -411,8 +412,9 @@ class TVM_DLL Analyzer {
* *
* \param var The variable. * \param var The variable.
* \param expr The expression we bind to. * \param expr The expression we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/ */
void Bind(const Var& var, const PrimExpr& expr); void Bind(const Var& var, const PrimExpr& expr, bool override = false);
/*! /*!
* \brief Notify all the sub-analyzers that var * \brief Notify all the sub-analyzers that var
* is created and binded to a range. * is created and binded to a range.
...@@ -421,14 +423,16 @@ class TVM_DLL Analyzer { ...@@ -421,14 +423,16 @@ class TVM_DLL Analyzer {
* *
* \param var The variable. * \param var The variable.
* \param range The range we bind to. * \param range The range we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/ */
void Bind(const Var& var, const Range& range); void Bind(const Var& var, const Range& range, bool override = false);
/*! /*!
* \brief Bind all the vars in the Map * \brief Bind all the vars in the Map
* *
* \param variables The {variable -> range} map. * \param variables The {variable -> range} map.
* \param override Whether we allow overriding an existing var's expression.
*/ */
void Bind(const Map<Var, Range>& variables); void Bind(const Map<Var, Range>& variables, bool override = false);
/*! /*!
* \brief Whether can we prove expr >= val. * \brief Whether can we prove expr >= val.
...@@ -443,6 +447,19 @@ class TVM_DLL Analyzer { ...@@ -443,6 +447,19 @@ class TVM_DLL Analyzer {
*/ */
bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound); bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*! /*!
* \brief Whether can we prove expr < val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param upper_bound The upper bound.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
/*!
* \brief Whether can we prove condition. * \brief Whether can we prove condition.
* *
* \param cond The expression to be proved. * \param cond The expression to be proved.
......
...@@ -153,6 +153,13 @@ class IntSet : public ObjectRef { ...@@ -153,6 +153,13 @@ class IntSet : public ObjectRef {
// Integer set legacy API. // Integer set legacy API.
//------------------------------------------------ //------------------------------------------------
/*! /*!
* \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
*
* \param dom_map The domain map to convert.
* \return The converted map.
*/
Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains all possible values of * \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables. * e given the domain of each iteration variables.
* *
...@@ -160,8 +167,7 @@ class IntSet : public ObjectRef { ...@@ -160,8 +167,7 @@ class IntSet : public ObjectRef {
* \param dom_map The domain of each variable. * \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e. * \return An integer set that can cover all the possible values of e.
*/ */
IntSet EvalSet(PrimExpr e, IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
const Map<IterVar, IntSet>& dom_map);
/*! /*!
* \brief Same as EvalSet, but takes unordered_map * \brief Same as EvalSet, but takes unordered_map
* *
...@@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e, ...@@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e,
*/ */
IntSet EvalSet(PrimExpr e, IntSet EvalSet(PrimExpr e,
const std::unordered_map<const tir::VarNode*, IntSet>& dom_map); const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
/*! /*!
* \brief Find an symbolic integer set that contains is union over * \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map. * all the possible conditional values in dom_map.
...@@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s, ...@@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s,
*/ */
IntSet EvalSet(Range r, IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map); const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */ /*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>; using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*! /*!
......
...@@ -36,31 +36,31 @@ Analyzer::Analyzer() ...@@ -36,31 +36,31 @@ Analyzer::Analyzer()
int_set(this) { int_set(this) {
} }
void Analyzer::Bind(const Var& var, const PrimExpr& expr) { void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
PrimExpr new_expr = expr; PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr); new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr); new_expr = this->rewrite_simplify(new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr)); this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
this->modular_set.Update(var, this->modular_set(new_expr)); this->modular_set.Update(var, this->modular_set(new_expr), override);
this->rewrite_simplify.Update(var, new_expr); this->rewrite_simplify.Update(var, new_expr, override);
this->canonical_simplify.Update(var, new_expr); this->canonical_simplify.Update(var, new_expr, override);
} }
void Analyzer::Bind(const Var& var, const Range& range) { void Analyzer::Bind(const Var& var, const Range& range, bool override) {
CHECK(range.defined()); CHECK(range.defined());
if (tir::is_one(range->extent)) { if (tir::is_one(range->extent)) {
this->Bind(var, range->min); this->Bind(var, range->min, override);
} else { } else {
this->const_int_bound.Bind(var, range); this->const_int_bound.Bind(var, range, override);
} }
// skip modular_set // skip modular_set
// skip rewrite simplify // skip rewrite simplify
} }
void Analyzer::Bind(const Map<Var, Range>& variables) { void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
for (const auto& iter : variables) { for (const auto& iter : variables) {
this->Bind(iter.first, iter.second); this->Bind(iter.first, iter.second, override);
} }
} }
...@@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { ...@@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
return false; return false;
} }
bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value < upper_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->max_value < upper_bound) return true;
return false;
}
bool Analyzer::CanProve(const PrimExpr& expr) { bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<IntImmNode>()) { if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0; return ptr->value != 0;
......
...@@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl :
} }
}; };
void Bind(const Var& var, const Range& range) { void Bind(const Var& var, const Range& range, bool override) {
Entry a = VisitExpr(range->min); Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent); Entry b = VisitExpr(range->extent);
Entry ret; Entry ret;
ret.min_value = a.min_value; ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
Update(var, ret, false); Update(var, ret, override);
} }
void Update(const Var& var, void Update(const Var& var,
...@@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl : ...@@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl :
const PrimExprNode* op = expr.as<PrimExprNode>(); const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op); auto val = bound_->find(op);
if (val != bound_->end()) { if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value && auto everything = Everything(op->dtype);
val->second->max_value == res.max_value) CHECK(
<< "Detected bound for " << expr (val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
<< "conflicts with memorization"; (val->second->min_value == everything.min_value &&
val->second->max_value == everything.max_value))
<< "Detected bound for " << expr << "conflicts with memorization";
} }
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value); (*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
} }
...@@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var, ...@@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var,
impl_->Update(var, info, override); impl_->Update(var, info, override);
} }
void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
impl_->Bind(var, range); impl_->Bind(var, range, override);
} }
std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) { std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
......
...@@ -311,6 +311,16 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer, ...@@ -311,6 +311,16 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
LOG(FATAL) << "Modular by zero in CombineInterval Mod"; LOG(FATAL) << "Modular by zero in CombineInterval Mod";
} }
if (analyzer->CanProveGreaterEqual(divisor, 0)) { if (analyzer->CanProveGreaterEqual(divisor, 0)) {
if (divisor.as<tir::IntImmNode>()) {
// a mod b = a - (a / b) * b if a_max / b == a_min / b
auto qmax = floordiv(a->max_value, divisor);
auto qmin = floordiv(a->min_value, divisor);
if (analyzer->CanProve(qmax == qmin)) {
auto tmax = a->max_value - divisor * qmin;
auto tmin = a->min_value - divisor * qmin;
return IntervalSet(tmin, tmax);
}
}
return IntervalSet(make_zero(divisor.dtype()), divisor - 1); return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else { } else {
PrimExpr bound = abs(divisor) - 1; PrimExpr bound = abs(divisor) - 1;
......
...@@ -231,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs( ...@@ -231,7 +231,7 @@ void ComputeOpNode::PropBoundToInputs(
// undefined behaviour), so we can intersect the estimated set of the argument with the // undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex // range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection. // expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = EvalSet(call->args[i], dom_map); IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map));
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>(); const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) { if (arg_interval) {
PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype()); PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
...@@ -239,12 +239,14 @@ void ComputeOpNode::PropBoundToInputs( ...@@ -239,12 +239,14 @@ void ComputeOpNode::PropBoundToInputs(
PrimExpr min_value = arg_interval->min_value; PrimExpr min_value = arg_interval->min_value;
PrimExpr max_value = arg_interval->max_value; PrimExpr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter. // Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) || // We must update bound's ends in pairs. Here is an counter example: shape_i is
analyzer->CanProve(shape_i_min_value >= min_value)) { // [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
// [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0],
// awkward for further analysis.
if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
(analyzer->CanProve(shape_i_min_value >= min_value) &&
analyzer->CanProve(shape_i_max_value <= max_value))) {
min_value = shape_i_min_value; min_value = shape_i_min_value;
}
if (arith::is_pos_inf(max_value) ||
analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value; max_value = shape_i_max_value;
} }
dom.data[i].push_back(IntSet::interval(min_value, max_value)); dom.data[i].push_back(IntSet::interval(min_value, max_value));
......
...@@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage, ...@@ -137,7 +137,7 @@ void InferRootBound(const Stage& stage,
Array<IterVar> stage_attach = ctx.attach_path.at(stage->op); Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
// The parent set. // The parent set.
for (const Operation& op : consumers) { for (const Operation& op : consumers) {
std::unordered_map<const VarNode*, IntSet> relax_set; Map<Var, IntSet> relax_set;
std::unordered_map<IterVar, IntSet> up_state; std::unordered_map<IterVar, IntSet> up_state;
bool found_attach = false; bool found_attach = false;
CHECK(ctx.op2stage_.count(op.get())); CHECK(ctx.op2stage_.count(op.get()));
...@@ -176,9 +176,9 @@ void InferRootBound(const Stage& stage, ...@@ -176,9 +176,9 @@ void InferRootBound(const Stage& stage,
<< "InferBound requires every leaf iter var's min equals 0, " << "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this."; << "call schedule.normalize to achieve this.";
if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) { if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
relax_set[iv->var.get()] = IntSet::range(vrange); relax_set.Set(iv->var, IntSet::range(vrange));
if (ctx.bind_map.count(iv)) { if (ctx.bind_map.count(iv)) {
relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange); relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange));
} }
} }
} }
...@@ -190,6 +190,9 @@ void InferRootBound(const Stage& stage, ...@@ -190,6 +190,9 @@ void InferRootBound(const Stage& stage,
// Relax if needed. // Relax if needed.
std::unordered_map<const VarNode*, IntSet> dom_map; std::unordered_map<const VarNode*, IntSet> dom_map;
arith::Analyzer analyzer; arith::Analyzer analyzer;
for (auto entry : *rmap) {
analyzer.Bind(entry.first->var, entry.second);
}
for (auto iv : op->root_iter_vars()) { for (auto iv : op->root_iter_vars()) {
Range r; Range r;
if (up_state.count(iv)) { if (up_state.count(iv)) {
...@@ -198,11 +201,13 @@ void InferRootBound(const Stage& stage, ...@@ -198,11 +201,13 @@ void InferRootBound(const Stage& stage,
r = iv->dom; r = iv->dom;
} }
if (relax_set.size() != 0) { if (relax_set.size() != 0) {
dom_map[iv->var.get()] = EvalSet(r, relax_set); dom_map[iv->var.get()] = IntSet::interval(
analyzer.int_set(r->min, relax_set).min(),
analyzer.int_set(r->min + r->extent - 1, relax_set).max());
} else { } else {
dom_map[iv->var.get()] = IntSet::range(r); dom_map[iv->var.get()] = IntSet::range(r);
} }
analyzer.Bind(iv->var, r); analyzer.Bind(iv->var, r, true);
} }
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap); op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
} }
......
...@@ -579,11 +579,15 @@ std::vector<PrimExpr> MakeBoundCheck( ...@@ -579,11 +579,15 @@ std::vector<PrimExpr> MakeBoundCheck(
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer); PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);
std::vector<PrimExpr> preds; std::vector<PrimExpr> preds;
std::unordered_map<const VarNode*, IntSet> iset_dmap; Map<Var, IntSet> iset_dmap;
// setup domain map for set analysis // setup domain map for set analysis
for (const auto& kv : dom_map) { for (const auto& kv : dom_map) {
iset_dmap[kv.first->var.get()] = IntSet::range(kv.second); iset_dmap.Set(kv.first->var, IntSet::range(kv.second));
}
for (auto entry : dom_map) {
analyzer.Bind(entry.first->var, entry.second);
} }
for (const IterVar& iv : stage->all_iter_vars) { for (const IterVar& iv : stage->all_iter_vars) {
...@@ -591,7 +595,7 @@ std::vector<PrimExpr> MakeBoundCheck( ...@@ -591,7 +595,7 @@ std::vector<PrimExpr> MakeBoundCheck(
if (bound_state.at(iv)) { if (bound_state.at(iv)) {
Range dom = dom_map.at(iv); Range dom = dom_map.at(iv);
PrimExpr value = value_map.at(iv) - dom->min; PrimExpr value = value_map.at(iv) - dom->min;
PrimExpr vmax = EvalSet(value, iset_dmap).max(); PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) { if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent); preds.emplace_back(value < dom->extent);
} }
...@@ -603,7 +607,7 @@ std::vector<PrimExpr> MakeBoundCheck( ...@@ -603,7 +607,7 @@ std::vector<PrimExpr> MakeBoundCheck(
CHECK(iv->dom.defined()); CHECK(iv->dom.defined());
if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) { if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
PrimExpr value = value_map.at(iv) - iv->dom->min; PrimExpr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap); IntSet s = analyzer.int_set(value, iset_dmap);
PrimExpr vmin = s.min(); PrimExpr vmin = s.min();
PrimExpr vmax = s.max(); PrimExpr vmax = s.max();
// The range of `value` resides in [vmin, vmax] // The range of `value` resides in [vmin, vmax]
......
...@@ -90,6 +90,20 @@ def test_mod(): ...@@ -90,6 +90,20 @@ def test_mod():
flm = tvm.te.floormod flm = tvm.te.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9)) ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9))
floordiv = tvm.te.floordiv
z = te.var("z")
ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3))
ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
(0, 7))
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2))
ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))
def test_max_min(): def test_max_min():
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_bound_tile_mod():
def compute(M_tiles, N_tiles, factor, dtype):
# Algo
M = M_tiles * factor
N = N_tiles * factor
A = tvm.te.placeholder((N, M), name='A', dtype=dtype)
C = tvm.te.compute((N, M), lambda n, m: A[n, m], name='C')
s = tvm.te.create_schedule(C.op)
return s, A, C
def schedule(s, factor, padding, A, C):
C_local = s.cache_write(C, "local")
n, m = C.op.axis
bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
nio, nii = s[C].split(ni, 2)
n = s[C].fuse(nii, mi)
C_shared = s.cache_write(C, "shared")
bn, bm, ni, mi = C_shared.op.axis
s[C_shared].storage_align(ni, factor * 2, padding)
n, m = s[C].op.axis
bn, bm, ni, mi = s[C].tile(n, m, factor, factor)
s[C].set_scope("global")
niio, niii = s[C].split(ni, 32)
s[C_shared].compute_at(s[C], niio)
return s
s, A, C = compute(2, 2, 128, "float16")
s = schedule(s, 128, 8, A, C)
bounds = tvm.te.schedule.InferBound(s)
check = (bounds[s.stages[2].op.axis[2]].extent == 16)
if(not check):
print(tvm.lower(s, [A, C], simple_mode=True))
assert(check)
if __name__ == "__main__":
test_bound_tile_mod()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment