Commit e15aae2b by Tianqi Chen Committed by GitHub

[SCHEDULE][PASS] Enable Warp memory and lower to shuffle (#1050)

* [SCHEDULE][PASS] Enable Warp memory and lower to shuffle

* OpenCL dispatches for now to intel shuffle
parent cc71d505
...@@ -412,6 +412,14 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered"; ...@@ -412,6 +412,14 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
*/ */
constexpr const char* tvm_storage_sync = "tvm_storage_sync"; constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*! /*!
* \brief See pseudo code
*
* Type tvm_warp_shuffle(Type value, warp_id) {
* return (value passed in by warp indicated by warp_id);
* }
*/
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
/*!
* \brief Initialize the global barrier. * \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier. * Call this at beginning of kernel that need global barrier.
*/ */
......
...@@ -408,6 +408,15 @@ LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope); ...@@ -408,6 +408,15 @@ LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size); LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
/*! /*!
* \brief Lower warp memory in stmt.
* \param f The device function to be lowered.
* \param warp_size the size of warp where no sync is needed.
* this function will only take in effect if warp_size is bigger than one.
* \return Transformed function.
*/
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
/*!
* \brief Lower packed function call. * \brief Lower packed function call.
* \param f The function to be lowered. * \param f The function to be lowered.
* \return Transformed function. * \return Transformed function.
......
...@@ -450,6 +450,10 @@ def build(sch, ...@@ -450,6 +450,10 @@ def build(sch,
else: else:
raise ValueError("unknown function type %d" % func.func_type) raise ValueError("unknown function type %d" % func.func_type)
for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)
if "gpu" in target.keys and not fdevice: if "gpu" in target.keys and not fdevice:
warnings.warn( warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target) "Specified target %s, but cannot find device code, did you do bind?" % target)
......
...@@ -125,6 +125,7 @@ REGISTER_PASS2(SplitPipeline); ...@@ -125,6 +125,7 @@ REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope); REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess); REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce); REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(LowerIntrin); REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin); REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall); REGISTER_PASS1(CombineContextCall);
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <iomanip> #include <iomanip>
#include <cctype> #include <cctype>
#include "./codegen_c.h" #include "./codegen_c.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h" #include "../arithmetic/compute_expr.h"
namespace tvm { namespace tvm {
...@@ -544,15 +545,6 @@ void CodeGenC::PrintVecBinaryOp( ...@@ -544,15 +545,6 @@ void CodeGenC::PrintVecBinaryOp(
} }
} }
inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
const Ramp* r = index.as<Ramp>();
if (!r) return false;
if (!is_one(r->stride)) return false;
CHECK_EQ(r->lanes, lanes);
*base = r->base;
return true;
}
void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes(); int lanes = op->type.lanes();
// delcare type. // delcare type.
...@@ -563,7 +555,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*) ...@@ -563,7 +555,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
CHECK(is_one(op->predicate)) CHECK(is_one(op->predicate))
<< "predicated load is not supported"; << "predicated load is not supported";
Expr base; Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) { if (GetRamp1Base(op->index, op->type.lanes(), &base)) {
std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base); std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
os << ref; os << ref;
} else { } else {
...@@ -617,7 +609,7 @@ void CodeGenC::VisitStmt_(const Store* op) { ...@@ -617,7 +609,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate)) CHECK(is_one(op->predicate))
<< "Predicated store is not supported"; << "Predicated store is not supported";
Expr base; Expr base;
if (TryGetRamp1Base(op->index, t.lanes(), &base)) { if (GetRamp1Base(op->index, t.lanes(), &base)) {
std::string value = this->PrintExpr(op->value); std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value); this->PrintVecStore(op->buffer_var.get(), t, base, value);
} else { } else {
......
...@@ -49,6 +49,12 @@ struct CUDAPopcount { ...@@ -49,6 +49,12 @@ struct CUDAPopcount {
} }
}; };
struct CUDAShuffle {
std::string operator()(Type t, std::string name) const {
return "__shfl";
}
};
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>); .set_body(DispatchExtern<CUDAFastMath>);
...@@ -67,6 +73,10 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow") ...@@ -67,6 +73,10 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>); .set_body(DispatchExtern<CUDAPopcount>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchExtern<CUDAShuffle>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -27,6 +27,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow") ...@@ -27,6 +27,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount") TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
.set_body(DispatchExtern<Direct>); .set_body(DispatchExtern<Direct>);
// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
struct IntelShuffle {
std::string operator()(Type t, std::string name) const {
return "intel_sub_group_shuffle";
}
};
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
.set_body(DispatchExtern<IntelShuffle>);
} // namespace intrin } // namespace intrin
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -161,6 +161,23 @@ inline int GetTempAllocaAlignment(Type type, int32_t const_size) { ...@@ -161,6 +161,23 @@ inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
} }
return align; return align;
} }
/*!
* \brief Pattern match index to Ramp with stride=1
* This is a common pattern in continuous memory load.
* \param index The index formula
* \param lanes number of lanes in the ramp
* \param base The result base.
* \return true if pattern match success and store the base to base.
*/
inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
const Ramp* r = index.as<Ramp>();
if (!r) return false;
if (!is_one(r->stride)) return false;
CHECK_EQ(r->lanes, lanes);
*base = r->base;
return true;
}
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_ #endif // TVM_PASS_IR_UTIL_H_
/*!
* Copyright (c) 2018 by Contributors
*
* Lower warp memory to use local memory
* and shuffle intrinsics.
*
* \file lower_warp_memory.cc
*/
// Thanks to Andrew Adams and Vinod Grover for
// explaining the concept of warp shuffle.
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_pass.h>
#include <unordered_set>
#include "./ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace ir {
// Rewrite Rule
//
// There is no special warp memory in most GPUs.
// Instead, we can stripe the data into threads
// and store the data into local memory.
//
// This requires us to do the following rewriting:
// - Rewrite allocation to use local memory.
// - Rewrite store of warp memory to local store.
// - Rewrite load of waro memory to local plus a shuffle.
//
// Define a generic shuffle instrinsic warp_shuffle(data, warp_index).
// We can use the following rewriting rule
//
// Before rewrite,
//
// alloc warp warp_mem[n * warp_size * m]
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
// load warp_mem[m * z + (warp_size * m) * y + x]
// subject to x \in [0, m), y \in [0, n)
//
// After rewrite:
//
// alloc local local_mem[n * m]
// store warp_mem[m * y + x]
// warp_shuffle(load warp_mem[m * y + x], z)
// subject to (m * y + x) is invariant to warp_index
// Algorithm
//
// To implement this rewrite rule, we can do the follow step:
// For each warp memory alloc
// - Use linear pattern detector on load index to find m
// - Deduce n given warp_size and alloc size
// - Now that we have m, n, warp_size, we can proceed with the rewrite
// Visitor to find m in pattern
// store warp_mem[m * warp_index + (warp_size * m) * y + x]
class WarpStoreCoeffFinder : private IRVisitor {
public:
WarpStoreCoeffFinder(const Variable* buffer,
Var warp_index)
: buffer_(buffer), warp_index_(warp_index) {
}
// find the warp co-efficient in the statement given the warp size
int Find(const Stmt& stmt) {
this->Visit(stmt);
return warp_coeff_;
}
private:
/// Visitor implementation
void Visit_(const Store *op) final {
if (op->buffer_var.get() == buffer_) {
if (op->value.type().lanes() == 1) {
UpdatePattern(op->index);
} else {
Expr base;
CHECK(GetRamp1Base(op->index, op->value.type().lanes(), &base))
<< "LowerWarpMemory failed due to store index=" << op->index
<< ", can only handle continuous store";
UpdatePattern(base);
}
} else {
IRVisitor::Visit_(op);
}
}
void UpdatePattern(const Expr& index) {
Array<Expr> m =
arith::DetectLinearEquation(index, {warp_index_});
CHECK_EQ(m.size(), 2U)
<< "LowerWarpMemory failed due to store index=" << index;
int coeff;
CHECK(arith::GetConstInt(ir::Simplify(m[0]), &coeff) && coeff > 0)
<< "LowerWarpMemory failed due to store index=" << index
<< ", require positive constant coefficient on warp index";
if (warp_coeff_ != 0) {
CHECK_EQ(warp_coeff_, coeff)
<< "LowerWarpMemory failed due to two different store coefficient to warp index";
} else {
warp_coeff_ = coeff;
}
}
// The buffer variable
const Variable* buffer_;
// the warp index
Var warp_index_;
// the coefficient
int warp_coeff_{0};
};
// Visitor to find the warp index
class WarpIndexFinder : private IRVisitor {
public:
explicit WarpIndexFinder(int warp_size)
: warp_size_(warp_size) {
}
// find the warp co-efficient in the statement given the warp size
IterVar Find(const Stmt& stmt) {
this->Visit(stmt);
CHECK(warp_index_.defined())
<< "Cannot find warp index(threadIdx.x) within the scope of warp memory";
return warp_index_;
}
private:
void Visit(const NodeRef &node) final {
if (warp_index_.defined()) return;
IRVisitor::Visit(node);
}
/// Visitor implementation
void Visit_(const AttrStmt *op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv(op->node.node_);
if (iv->thread_tag == "threadIdx.x") {
int value;
CHECK(arith::GetConstInt(op->value, &value) &&
value == warp_size_)
<< "Expect threadIdx.x 's size to be equal to warp size("
<< warp_size_ << ")" << " to enable warp memory"
<< " but get " << op->value << " instead";
warp_index_ = iv;
}
}
IRVisitor::Visit_(op);
}
// warp size
int warp_size_{0};
// the warp index
IterVar warp_index_{nullptr};
};
// Mutator to change the read pattern
class WarpAccessRewriter : protected IRMutator {
public:
explicit WarpAccessRewriter(int warp_size)
: warp_size_(warp_size) {}
// Rewrite the allocate statement which transforms
// warp memory to local memory.
Stmt Rewrite(const Allocate* op, const Stmt& stmt) {
buffer_ = op->buffer_var.get();
int alloc_size = op->constant_allocation_size();
CHECK_GT(alloc_size, 0)
<< "warp memory only support constant alloc size";
alloc_size *= op->type.lanes();
warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var;
warp_coeff_ = WarpStoreCoeffFinder(
buffer_, warp_index_).Find(op->body);
CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0)
<< "Warp memory must be multiple of warp size";
warp_group_ = alloc_size / (warp_size_ * warp_coeff_);
return Allocate::make(
op->buffer_var,
op->type,
{make_const(Int(32), alloc_size / warp_size_)},
op->condition,
this->Mutate(op->body));
}
protected:
Expr Mutate_(const Variable* op, const Expr& expr) {
CHECK(op != buffer_)
<< "Cannot access address of warp memory directly";
return IRMutator::Mutate_(op, expr);
}
Stmt Mutate_(const Store* op, const Stmt& stmt) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
return Store::make(op->buffer_var, op->value, local_index, op->predicate);
} else {
return IRMutator::Mutate_(op, stmt);
}
}
Expr Mutate_(const Load* op, const Expr& expr) {
if (op->buffer_var.get() == buffer_) {
Expr local_index, group;
std::tie(local_index, group) = SplitIndexByGroup(op->index);
// invariance: local index must do not contain warp id
CHECK(!ExprUseVar(local_index, {warp_index_.get()}))
<< "LowerWarpMemory failed to rewrite load to shuffle for index "
<< op->index << " local_index=" << local_index;
Expr load_value = Load::make(
op->type, op->buffer_var, local_index, op->predicate);
return Call::make(load_value.type(),
intrinsic::tvm_warp_shuffle,
{load_value, group},
Call::Intrinsic);
} else {
return IRMutator::Mutate_(op, expr);
}
}
// Split the index to the two component
// <local_index, source_index>
// local index is the index in the local
// source index is the corresponding source index
// in this access pattern.
std::pair<Expr, Expr> SplitIndexByGroup(const Expr& index) {
if (index.type().lanes() != 1) {
Expr base, local_index, group;
CHECK(GetRamp1Base(index, index.type().lanes(), &base));
std::tie(local_index, group) = SplitIndexByGroup(base);
local_index =
Ramp::make(local_index, make_const(local_index.type(), 1), index.type().lanes());
return std::make_pair(local_index, group);
}
Expr m = make_const(index.type(), warp_coeff_);
Range rng = Range::make_by_min_extent(
make_zero(index.type()), make_const(index.type(), warp_size_));
Map<Var, Range> vrange({{warp_index_, rng}});
// simple case, warp index is on the highest.
if (warp_group_ == 1) {
Expr x = Simplify(index % m, vrange);
Expr z = Simplify(index / m, vrange);
return std::make_pair(x, z);
} else {
Expr x = Simplify(index % m, vrange);
Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_);
y = y * m + x;
Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m;
return std::make_pair(Simplify(y, vrange), Simplify(z, vrange));
}
}
private:
// the warp size
int warp_size_{0};
// The buffer variable
const Variable* buffer_;
// Warp index
Var warp_index_;
// the coefficient m
int warp_coeff_{0};
// the coefficient n
int warp_group_{0};
};
// Mutator to change the read pattern
class WarpMemoryRewriter : private IRMutator {
public:
explicit WarpMemoryRewriter(int warp_size)
: warp_size_(warp_size) {
}
Stmt Rewrite(Stmt stmt) {
if (warp_size_ == 1) return stmt;
return this->Mutate(stmt);
}
private:
Stmt Mutate_(const Allocate* op, const Stmt& stmt) {
if (warp_buffer_.count(op->buffer_var.get())) {
WarpAccessRewriter rewriter(warp_size_);
return rewriter.Rewrite(op, stmt);
} else {
return IRMutator::Mutate_(op, stmt);
}
}
Stmt Mutate_(const AttrStmt* op, const Stmt& stmt) {
using runtime::StorageScope;
if (op->attr_key == attr::storage_scope) {
const Variable* buf = op->node.as<Variable>();
StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
if (scope.rank == runtime::StorageRank::kWarp) {
warp_buffer_.insert(buf);
Stmt ret = IRMutator::Mutate_(op, stmt);
op = ret.as<AttrStmt>();
return AttrStmt::make(
op->node, op->attr_key, StringImm::make("local"), op->body);
}
}
return IRMutator::Mutate_(op, stmt);
}
int warp_size_{0};
std::unordered_set<const Variable*> warp_buffer_;
};
LoweredFunc
LowerWarpMemory(LoweredFunc f, int warp_size) {
CHECK_EQ(f->func_type, kDeviceFunc);
auto n = std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = WarpMemoryRewriter(warp_size).Rewrite(n->body);
return LoweredFunc(n);
}
} // namespace ir
} // namespace tvm
...@@ -42,7 +42,16 @@ bool NeedRelax(const IterVar& iv, ...@@ -42,7 +42,16 @@ bool NeedRelax(const IterVar& iv,
if (tag.length() == 0 || tag == "pipeline") { if (tag.length() == 0 || tag == "pipeline") {
return !found_attach; return !found_attach;
} }
return static_cast<int>(scope.rank) <= ThreadScope::make(tag).rank; ThreadScope ts = ThreadScope::make(tag);
// When there is warp memory
// threadIdx.x must be set to be warp index.
if (scope.rank == StorageRank::kWarp &&
ts.rank == 1 &&
ts.dim_index == 0) {
return true;
}
return static_cast<int>(scope.rank) <= ts.rank;
} }
// infer storage scope, if not given // infer storage scope, if not given
......
import tvm import tvm
from tvm.contrib import nvcc
import numpy as np import numpy as np
import time import time
...@@ -155,7 +156,46 @@ def test_add(): ...@@ -155,7 +156,46 @@ def test_add():
run("uint64") run("uint64")
def try_warp_memory():
"""skip this in default test because it require higher arch"""
m = 128
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i] + 3, name='B')
warp_size = 32
s = tvm.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], warp_size * 2)
xi0, xi1 = s[B].split(xi, factor=warp_size)
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], warp_size)
s[AA].bind(xi, tx)
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx")
return ptx
# one line to build the function.
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("skip because %s is not enabled.." % device)
return
f = tvm.build(s, [A, B], device)
a = tvm.nd.array((np.random.uniform(size=m) * 256).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(m, dtype=B.dtype), ctx)
f(a, b)
np.testing.assert_allclose(
b.asnumpy(), a.asnumpy() + 3, rtol=1e-6)
check_device("cuda")
if __name__ == "__main__": if __name__ == "__main__":
try_warp_memory()
test_add() test_add()
test_log_pow_llvm() test_log_pow_llvm()
test_exp() test_exp()
......
import tvm
def test_lower_warp_mem():
m = 128
A = tvm.placeholder((m,), name='A')
B = tvm.compute((m,), lambda i: A[i] + 3, name='B')
s = tvm.create_schedule(B.op)
AA = s.cache_read(A, "warp", [B])
xo, xi = s[B].split(B.op.axis[0], 32)
xi0, xi1 = s[B].split(xi, factor=16)
tx = tvm.thread_axis("threadIdx.x")
s[B].bind(xi1, tx)
s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
s[AA].compute_at(s[B], xo)
xo, xi = s[AA].split(s[AA].op.axis[0], 16)
s[AA].bind(xi, tx)
f = tvm.lower(s, [A, B])
fhost, fdevice = tvm.ir_pass.SplitHostDevice(f)
fdevice = tvm.ir_pass.LowerWarpMemory(fdevice, 16)
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)
if __name__ == "__main__":
test_lower_warp_mem()
...@@ -53,6 +53,29 @@ def test_bound3(): ...@@ -53,6 +53,29 @@ def test_bound3():
assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[0]].extent.value==32)
assert(bounds[A1.op.axis[1]].extent.value==16) assert(bounds[A1.op.axis[1]].extent.value==16)
def test_bound_warp():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
A1 = tvm.compute((m, l), lambda i, j: A[i, j], name='A1')
A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3, name='A2')
s = tvm.create_schedule(A2.op)
s[A1].set_scope("warp")
xo, xi = s[A2].split(A2.op.axis[0], 32)
xi0, xi1 = s[A2].split(xi, factor=16)
tx = tvm.thread_axis("threadIdx.x")
s[A2].bind(xi1, tx)
s[A2].bind(xi0, tvm.thread_axis("threadIdx.y"))
y = s[A2].op.axis[1]
s[A1].compute_at(s[A2], y)
xo, xi = s[A1].split(s[A1].op.axis[0], factor=16)
s[A1].bind(xi, tx)
bounds = tvm.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
assert(bounds[A1.op.axis[0]].extent.value==16)
def test_bound_scan(): def test_bound_scan():
m = tvm.var("m") m = tvm.var("m")
n = tvm.var("n") n = tvm.var("n")
...@@ -249,3 +272,4 @@ if __name__ == "__main__": ...@@ -249,3 +272,4 @@ if __name__ == "__main__":
test_bound_conv1d() test_bound_conv1d()
test_bound2() test_bound2()
test_gemm_bound() test_gemm_bound()
test_bound_warp()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment