Commit 2c512ca7 by Tianqi Chen, committed by GitHub

[LLVM] Vectorized load/store (#60)

parent 2111bbf3
@@ -80,7 +80,7 @@ inline bool GetConstInt(Expr e, int* out) {
   } \
   uint64_t ua = 0, ub = 0; \
   if (GetConst(a, &ua) && GetConst(b, &ub)) { \
-    return ir::UIntImm::make(a.type(), ua + ub); \
+    return ir::UIntImm::make(a.type(), ua OP ub); \
   } \
 template<>
@@ -113,7 +113,7 @@ class ModularEvaluator
  private:
   const std::unordered_map<
       const Variable*, ModularEntry>& mod_map_;
+  friend struct ModularEntry;
   // simplify the base by putting it in range.
   static int BaseSimplify(int base, int coeff) {
     if (coeff == 0) return base;
@@ -136,6 +136,15 @@ class ModularEvaluator
   }
 };
+
+ModularEntry ModularEntry::Add(const ModularEntry& a,
+                               const ModularEntry& b) {
+  ModularEntry ret;
+  ret.coeff = ModularEvaluator::ZeroAwareGCD(a.coeff, b.coeff);
+  ret.base = ModularEvaluator::BaseSimplify(a.base + b.base, ret.coeff);
+  return ret;
+}
+
 ModularEntry EvalModular(
     const Expr& e,
     const std::unordered_map<const Variable*, ModularEntry>& mod_map) {
@@ -37,6 +37,14 @@ struct ModularEntry {
     e.base = 0; e.coeff = 1;
     return e;
   }
+  /*!
+   * \brief Add two modular entries together to get a new modular entry.
+   * \param a The left operand.
+   * \param b The right operand.
+   * \return The combined modular entry.
+   */
+  static ModularEntry Add(const ModularEntry& a,
+                          const ModularEntry& b);
 };
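Note: a ModularEntry records what is known about an integer expression in the form coeff * k + base, where coeff == 0 means the value is exactly base and coeff == 1 means nothing is known beyond the base. Add is sound because (a.base + a.coeff * i) + (b.base + b.coeff * j) is always congruent to a.base + b.base modulo gcd(a.coeff, b.coeff). A minimal self-contained sketch of the same arithmetic (std::gcd stands in for ZeroAwareGCD and a plain modulo for BaseSimplify; nonnegative values assumed):

    #include <cassert>
    #include <numeric>  // std::gcd (C++17)

    struct Entry { int coeff, base; };  // models coeff * k + base

    // Combine knowledge about x and y into knowledge about x + y.
    Entry AddEntries(Entry a, Entry b) {
      int c = std::gcd(a.coeff, b.coeff);  // 0 is the identity: gcd(0, n) == n
      int base = a.base + b.base;
      if (c != 0) base %= c;               // put the base back into [0, c)
      return {c, base};
    }

    int main() {
      // x = 8*i + 2, y = 4*j + 1  =>  x + y = 4*k + 3
      Entry e = AddEntries({8, 2}, {4, 1});
      assert(e.coeff == 4 && e.base == 3);
      return 0;
    }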
/*!
@@ -102,8 +102,14 @@ void CodeGenLLVM::InitGlobalContext() {
   gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
 }
-void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
+void CodeGenLLVM::InitFuncState() {
+  var_map_.clear();
+  align_map_.clear();
+  alloc_storage_scope_.clear();
+}
+
+void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
+  this->InitFuncState();
   CHECK(!module_->getFunction(f->name))
       << "Function " << f->name << " already exists in module";
   std::vector<llvm::Type*> arg_type;
@@ -163,6 +169,7 @@ class FPassManager : public llvm::legacy::FunctionPassManager {
     llvm::legacy::FunctionPassManager::add(p);
   }
 };
+
 class MPassManager : public llvm::legacy::PassManager {
  public:
   // override add to allow messaging
@@ -245,25 +252,26 @@ void CodeGenLLVM::AddAliasInfo(
   int base = 0, width = 0;
   // create meta-data for alias analysis
   // Use a group of binary tree ranges.
-  const Ramp* ramp = index.as<Ramp>();
-  if (ramp) {
-    int base, stride;
-    if (arith::GetConstInt(ramp->base, &base) &&
-        arith::GetConstInt(ramp->stride, &stride)) {
-      int xwidth = ramp->lanes * stride;
-      width = 1;
-      while (width < xwidth) {
-        width *= 2;
-      }
-      while (base % width) {
-        base -= base % width;
-        width *= 2;
-      }
-    }
-  } else {
-    if (arith::GetConstInt(index, &base)) width = 1;
-  }
+  if (index.defined()) {
+    const Ramp* ramp = index.as<Ramp>();
+    if (ramp) {
+      int base, stride;
+      if (arith::GetConstInt(ramp->base, &base) &&
+          arith::GetConstInt(ramp->stride, &stride)) {
+        int xwidth = ramp->lanes * stride;
+        width = 1;
+        while (width < xwidth) {
+          width *= 2;
+        }
+        while (base % width) {
+          base -= base % width;
+          width *= 2;
+        }
+      }
+    } else {
+      if (arith::GetConstInt(index, &base)) width = 1;
+    }
+  }
 llvm::MDNode* meta = md_tbaa_root_;
 std::ostringstream buffer_addr;
 buffer_addr << buffer;
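Note on the range computation above: for a constant ramp, the first loop grows width to the next power of two at least lanes * stride, and the second loop repeatedly rounds base down to a multiple of width, doubling width each time, so the access always lands in one power-of-two window [base, base + width) that the binary-tree alias tags can name. A standalone re-run of that arithmetic (illustrative values):

    #include <cassert>

    // Compute the power-of-two window [base, base + width) covering
    // elements [start, start + extent), the way AddAliasInfo does.
    void AliasWindow(int start, int extent, int* out_base, int* out_width) {
      int width = 1;
      while (width < extent) {
        width *= 2;
      }
      int base = start;
      while (base % width) {
        base -= base % width;
        width *= 2;
      }
      *out_base = base;
      *out_width = width;
    }

    int main() {
      int base, width;
      // A ramp of 4 elements starting at 6: window [0, 16).
      AliasWindow(6, 4, &base, &width);
      assert(base == 0 && width == 16);
      // Starting at 8: already aligned, window [8, 12).
      AliasWindow(8, 4, &base, &width);
      assert(base == 8 && width == 4);
      return 0;
    }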
@@ -283,12 +291,12 @@ void CodeGenLLVM::AddAliasInfo(
 }

 llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
-  llvm::Constant* init = llvm::UndefValue::get(
+  llvm::Constant* undef = llvm::UndefValue::get(
       llvm::VectorType::get(value->getType(), lanes));
   llvm::Constant* zero = ConstInt32(0);
-  value = builder_->CreateInsertElement(init, value, zero);
+  value = builder_->CreateInsertElement(undef, value, zero);
   llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
-  return builder_->CreateShuffleVector(value, init, mask);
+  return builder_->CreateShuffleVector(value, undef, mask);
 }
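Note: CreateBroadcast uses the canonical LLVM splat idiom: insert the scalar into lane 0 of an undef vector, then shufflevector with an all-zero mask so every result lane reads lane 0. A minimal sketch against the older (circa LLVM 4.x) API this file targets, where VectorType::get takes a lane count and ConstantVector::getSplat takes (unsigned, Constant*); the helper name is illustrative:

    #include <llvm/IR/IRBuilder.h>

    // Splat a scalar value into an n-lane vector.
    llvm::Value* SplatSketch(llvm::IRBuilder<>* b, llvm::Value* v, int lanes) {
      llvm::Type* vty = llvm::VectorType::get(v->getType(), lanes);
      // %tmp = insertelement <n x T> undef, T %v, i32 0
      llvm::Value* vec = b->CreateInsertElement(
          llvm::UndefValue::get(vty), v, b->getInt32(0));
      // The all-zero mask copies lane 0 of the first operand to every lane:
      // %res = shufflevector <n x T> %tmp, <n x T> undef, <n x i32> zeroinitializer
      llvm::Constant* mask =
          llvm::ConstantVector::getSplat(lanes, b->getInt32(0));
      return b->CreateShuffleVector(vec, llvm::UndefValue::get(vty), mask);
    }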
llvm::Value* CodeGenLLVM::CreateBufferPtr(
@@ -684,6 +692,38 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
   return nullptr;
 }
+
+int CodeGenLLVM::NativeVectorBits(const std::string& storage_scope) const {
+  // By default, we ask the buffer to be aligned to 64 bytes.
+  return 64 * 8;
+}
+
+void CodeGenLLVM::GetAlignment(
+    Type t, const Variable* buf_var, const Expr& index,
+    int* p_alignment, int* p_native_bits) {
+  int& alignment = *p_alignment;
+  int& native_bits = *p_native_bits;
+  // The storage scope.
+  std::string scope;
+  auto it = alloc_storage_scope_.find(buf_var);
+  if (it != alloc_storage_scope_.end()) {
+    scope = it->second;
+  }
+  arith::ModularEntry m = EvalModular(index, align_map_);
+  native_bits = NativeVectorBits(scope);
+  alignment = t.element_of().bits();
+  // find alignment
+  while ((m.coeff & 1) == 0 &&
+         (m.base & 1) == 0 &&
+         alignment < native_bits) {
+    m.coeff /= 2;
+    m.base /= 2;
+    alignment *= 2;
+  }
+  CHECK_EQ(alignment % 8, 0)
+      << "Load from memory that does not align to 8 bits";
+  alignment /= 8;
+}
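Note: GetAlignment turns the modular information about the index into a byte alignment: every time both coeff and base are even, the index is divisible by one more power of two, so the element-address alignment doubles, capped at the native vector width. A standalone re-run of the loop with concrete numbers (values illustrative):

    #include <cassert>

    // Re-runs GetAlignment's loop; elem_bits and the modular (coeff, base)
    // pair are the inputs.
    int AlignmentBytes(int elem_bits, int coeff, int base, int native_bits) {
      int alignment = elem_bits;
      // Both coeff and base even => the index is a multiple of one more
      // factor of 2, so the address alignment doubles.
      while ((coeff & 1) == 0 && (base & 1) == 0 && alignment < native_bits) {
        coeff /= 2;
        base /= 2;
        alignment *= 2;
      }
      return alignment / 8;
    }

    int main() {
      // float32 buffer, index of the form 8k + 4, 512-bit native width:
      // 4-aligned in elements => 16-byte aligned addresses.
      assert(AlignmentBytes(32, 8, 4, 512) == 16);
      // Odd base: nothing better than element alignment (4 bytes).
      assert(AlignmentBytes(32, 8, 3, 512) == 4);
      return 0;
    }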
// visitor overrides
llvm::Value* CodeGenLLVM::VisitExpr_(const Variable* op) {
return GetVarValue(op);
@@ -849,7 +889,9 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) {
 llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) {
   llvm::Value* v = MakeValue(op->value);
   CHECK(!var_map_.count(op->var.get()));
+  CHECK(!align_map_.count(op->var.get()));
   var_map_[op->var.get()] = v;
+  align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
   return MakeValue(op->body);
 }
@@ -872,25 +914,254 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Ramp* op) {
   return value;
 }
+
+void CodeGenLLVM::Scalarize(
+    const Expr& e,
+    std::function<void(int i, llvm::Value* v)> f) {
+  const Ramp* ramp = e.as<Ramp>();
+  Type t = e.type();
+  if (ramp) {
+    for (int i = 0; i < t.lanes(); ++i) {
+      Expr offset = arith::ComputeExpr<Add>(
+          ramp->base,
+          arith::ComputeExpr<Mul>(ramp->stride, i));
+      f(i, MakeValue(offset));
+    }
+  } else {
+    llvm::Value* index = MakeValue(e);
+    for (int i = 0; i < t.lanes(); ++i) {
+      f(i, builder_->CreateExtractElement(index, ConstInt32(i)));
+    }
+  }
+}
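Note: Scalarize is the catch-all path used by the Load/Store visitors below when no vectorized pattern applies: it calls f once per lane with the lane number and that lane's scalar index, computed symbolically when the index is a Ramp and extracted from the index vector otherwise. A toy model of the contract, with plain ints standing in for Expr and llvm::Value (names and types illustrative only):

    #include <functional>
    #include <vector>

    // f receives (lane, scalar index) for each lane, whether the index is an
    // arithmetic sequence (base + i * stride) or an arbitrary gather table.
    void ScalarizeToy(bool is_ramp, int base, int stride,
                      const std::vector<int>& gather, int lanes,
                      const std::function<void(int, int)>& f) {
      for (int i = 0; i < lanes; ++i) {
        f(i, is_ramp ? base + i * stride : gather[i]);
      }
    }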
+
+llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
+  int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
+  std::vector<llvm::Constant*> indices;
+  for (int i = lanes; i != 0; --i) {
+    indices.push_back(ConstInt32(i - 1));
+  }
+  llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
+  return builder_->CreateShuffleVector(
+      vec, undef, llvm::ConstantVector::get(indices));
+}
+
+llvm::Value* CodeGenLLVM::CreateVecSlice(
+    llvm::Value* vec, int begin, int lanes) {
+  int total_lanes = static_cast<int>(vec->getType()->getVectorNumElements());
+  CHECK_LE(begin + lanes, total_lanes);
+  if (lanes == total_lanes && begin == 0) return vec;
+  std::vector<llvm::Constant*> indices;
+  for (int i = 0; i < lanes; ++i) {
+    indices.push_back(ConstInt32(begin + i));
+  }
+  llvm::Constant* undef = llvm::UndefValue::get(vec->getType());
+  return builder_->CreateShuffleVector(
+      vec, undef, llvm::ConstantVector::get(indices));
+}
+
+llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
+  int lanes = static_cast<int>(vec->getType()->getVectorNumElements());
+  if (target_lanes == lanes) return vec;
+  CHECK_GT(target_lanes, lanes);
+  int pad_lanes = target_lanes - lanes;
+  llvm::Constant* undef = llvm::UndefValue::get(
+      llvm::VectorType::get(vec->getType()->getVectorElementType(), pad_lanes));
+  std::vector<llvm::Constant*> indices;
+  for (int i = 0; i < target_lanes; ++i) {
+    indices.push_back(ConstInt32(i));
+  }
+  return builder_->CreateShuffleVector(
+      vec, undef, llvm::ConstantVector::get(indices));
+}
+
+llvm::Value* CodeGenLLVM::CreateVecConcat(
+    std::vector<llvm::Value*> vec) {
+  CHECK_NE(vec.size(), 0U);
+  int target_lanes = 0;
+  for (llvm::Value* v : vec) {
+    target_lanes += static_cast<int>(v->getType()->getVectorNumElements());
+  }
+  // tree shape merging
+  while (vec.size() != 1) {
+    std::vector<llvm::Value*> merged;
+    for (size_t i = 0; i < vec.size() - 1; i += 2) {
+      llvm::Value* v1 = vec[i];
+      llvm::Value* v2 = vec[i + 1];
+      int w1 = static_cast<int>(v1->getType()->getVectorNumElements());
+      int w2 = static_cast<int>(v2->getType()->getVectorNumElements());
+      int w = std::max(w1, w2);
+      v1 = CreateVecPad(v1, w);
+      v2 = CreateVecPad(v2, w);
+      std::vector<llvm::Constant*> indices;
+      for (int i = 0; i < w * 2; ++i) {
+        indices.push_back(ConstInt32(i));
+      }
+      merged.push_back(
+          builder_->CreateShuffleVector(
+              v1, v2, llvm::ConstantVector::get(indices)));
+    }
+    if (vec.size() % 2 == 1) {
+      merged.push_back(vec.back());
+    }
+    vec = merged;
+  }
+  return CreateVecSlice(vec[0], 0, target_lanes);
+}
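Note: shufflevector masks index into the concatenation of both operands, so CreateVecPad widens a vector by shuffling it with an undef vector of pad lanes, CreateVecConcat merges pairs tree-fashion after padding both sides to equal width, and the final CreateVecSlice trims the undef tail back off. A plain-C++ simulation of the lane bookkeeping (a sketch, with -1 standing in for undef lanes; not LLVM API calls):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    using Vec = std::vector<int>;      // lane values; -1 stands for undef

    Vec Pad(Vec v, size_t target) {    // mirrors CreateVecPad
      v.resize(target, -1);
      return v;
    }

    Vec Merge(const Vec& a, const Vec& b) {  // one tree-merge step
      size_t w = std::max(a.size(), b.size());
      Vec va = Pad(a, w), vb = Pad(b, w), out;
      for (size_t i = 0; i < 2 * w; ++i)     // mask [0, 2w): a's lanes, then b's
        out.push_back(i < w ? va[i] : vb[i - w]);
      return out;
    }

    int main() {
      // Concat {0..3} + {4..7} + {8,9}: pass 1 merges the first pair, pass 2
      // merges the survivor with the padded leftover (16 lanes, 6 undef).
      Vec r = Merge(Merge({0, 1, 2, 3}, {4, 5, 6, 7}), {8, 9});
      r.resize(10);                          // final CreateVecSlice(0, 10)
      assert((r == Vec{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}));
      return 0;
    }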
 llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
   Type t = op->type;
-  CHECK(!t.is_vector());
+  const Ramp* ramp = op->index.as<Ramp>();
+  llvm::Value* buf = GetVarValue(op->buffer_var.get());
   if (t.is_scalar()) {
     llvm::LoadInst* inst = builder_->CreateAlignedLoad(
-        CreateBufferPtr(
-            t,
-            GetVarValue(op->buffer_var.get()),
-            MakeValue(op->index)),
+        CreateBufferPtr(t, buf, MakeValue(op->index)),
         data_layout_->getTypeAllocSize(LLVMType(t)));
     AddAliasInfo(inst, op->buffer_var.get(), op->index);
     return inst;
+  } else if (ramp && is_one(ramp->stride)) {
+    int alignment, native_bits;
+    GetAlignment(t, op->buffer_var.get(), ramp->base,
+                 &alignment, &native_bits);
+    int total_lanes = t.lanes();
+    int step = native_bits / t.bits();
+    std::vector<llvm::Value*> loads;
+    for (int offset = 0; offset < total_lanes; offset += step) {
+      int lanes = std::min(step, total_lanes - offset);
+      Expr base = arith::ComputeExpr<Add>(
+          ramp->base, make_const(ramp->base.type(), offset));
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
+      llvm::Type* vtype = llvm::VectorType::get(
+          LLVMType(t.element_of()), lanes)->getPointerTo();
+      llvm::LoadInst* inst = builder_->CreateAlignedLoad(
+          builder_->CreatePointerCast(ptr, vtype), alignment);
+      AddAliasInfo(inst, op->buffer_var.get(),
+                   Ramp::make(base, make_const(base.type(), 1), lanes));
+      loads.push_back(inst);
+    }
+    return CreateVecConcat(loads);
+  } else if (ramp && is_const(ramp->stride, 2)) {
+    int alignment, native_bits;
+    GetAlignment(t, op->buffer_var.get(), ramp->base,
+                 &alignment, &native_bits);
+    arith::ModularEntry e = arith::EvalModular(ramp->base, align_map_);
+    Type bt = ramp->base.type();
+    int first_shift, next_shift;
+    // If the base is even and the native alignment is at least twice the
+    // width of the type, loading at the base is safe.
+    if (e.coeff % 2 == 0 &&
+        e.base % 2 == 0 &&
+        native_bits >= t.bits() * 2) {
+      first_shift = 0;
+      next_shift = 0;
+    } else if (e.coeff % 2 == 0 && e.base % 2 == 1) {
+      // odd base, shift both loads to the left.
+      first_shift = -1;
+      next_shift = -1;
+    } else {
+      // safe option: keep the first load in place, shift the next load left.
+      first_shift = 0;
+      next_shift = -1;
+    }
+    llvm::Value* first = MakeValue(Load::make(
+        t, op->buffer_var,
+        Ramp::make(arith::ComputeExpr<Add>(
+            ramp->base, make_const(bt, first_shift)),
+            make_const(bt, 1), ramp->lanes)));
+    llvm::Value* next = MakeValue(Load::make(
+        t, op->buffer_var,
+        Ramp::make(arith::ComputeExpr<Add>(
+            ramp->base, make_const(bt, ramp->lanes + next_shift)),
+            make_const(bt, 1), ramp->lanes)));
+    // shuffle the requested stride-2 lanes out of the two loads.
+    std::vector<llvm::Constant*> indices;
+    int target_index = 0;
+    for (int i = 0; i < ramp->lanes; ++i) {
+      int idx = first_shift + i;
+      if (idx == target_index) {
+        indices.push_back(ConstInt32(i));
+        target_index += 2;
+      }
+    }
+    for (int i = 0; i < ramp->lanes; ++i) {
+      int idx = ramp->lanes + next_shift + i;
+      if (idx == target_index) {
+        indices.push_back(ConstInt32(i + ramp->lanes));
+        target_index += 2;
+      }
+    }
+    CHECK_EQ(indices.size(), static_cast<size_t>(ramp->lanes));
+    return builder_->CreateShuffleVector(
+        first, next, llvm::ConstantVector::get(indices));
+  } else if (ramp && is_const(ramp->stride, -1)) {
+    int lanes = ramp->type.lanes();
+    Expr neg_ramp = Ramp::make(
+        arith::ComputeExpr<Sub>(
+            ramp->base,
+            make_const(ramp->base.type(), lanes - 1)),
+        make_const(ramp->base.type(), 1),
+        lanes);
+    // load the value, then flip it.
+    llvm::Value* v = MakeValue(Load::make(t, op->buffer_var, neg_ramp));
+    return CreateVecFlip(v);
+  } else {
+    llvm::Value* ret = llvm::UndefValue::get(LLVMType(t));
+    Scalarize(op->index, [&](int i, llvm::Value* offset) {
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
+      llvm::LoadInst* inst = builder_->CreateAlignedLoad(
+          ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
+      AddAliasInfo(inst, op->buffer_var.get(), Expr());
+      ret = builder_->CreateInsertElement(ret, inst, ConstInt32(i));
+    });
+    return ret;
+  }
+}
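Note: the stride-2 branch deserves a worked example. The two unit-stride loads cover elements [first_shift, first_shift + n) and [n + next_shift, 2n + next_shift) relative to ramp->base, and the shuffle mask collects the lanes whose element offset is 0, 2, 4, and so on. A standalone re-computation of the mask for n = 4 under each shift policy (mirrors the two index loops above; values illustrative):

    #include <cassert>
    #include <vector>

    // Recompute the shuffle mask the stride-2 load builds for n lanes.
    std::vector<int> Stride2Mask(int n, int first_shift, int next_shift) {
      std::vector<int> indices;
      int target = 0;
      for (int i = 0; i < n; ++i) {   // lanes of the first load
        if (first_shift + i == target) { indices.push_back(i); target += 2; }
      }
      for (int i = 0; i < n; ++i) {   // lanes of the second load
        if (n + next_shift + i == target) {
          indices.push_back(n + i);
          target += 2;
        }
      }
      return indices;
    }

    int main() {
      // Even, well-aligned base (shifts 0/0): loads cover elements [0,4)
      // and [4,8); the stride-2 elements 0,2,4,6 sit at mask 0,2,4,6.
      assert((Stride2Mask(4, 0, 0) == std::vector<int>{0, 2, 4, 6}));
      // Odd base (shifts -1/-1): loads cover [-1,3) and [3,7); elements 0,2
      // sit at positions 1,3 of the first load, 4,6 at 1,3 of the second.
      assert((Stride2Mask(4, -1, -1) == std::vector<int>{1, 3, 5, 7}));
      // Safe fallback (shifts 0/-1): loads cover [0,4) and [3,7).
      assert((Stride2Mask(4, 0, -1) == std::vector<int>{0, 2, 5, 7}));
      return 0;
    }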
+// stmts
+void CodeGenLLVM::VisitStmt_(const Store* op) {
+  llvm::Value* value = MakeValue(op->value);
+  Type t = op->value.type();
+  const Ramp* ramp = op->index.as<Ramp>();
+  llvm::Value* buf = GetVarValue(op->buffer_var.get());
+  if (t.is_scalar()) {
+    llvm::StoreInst* inst = builder_->CreateAlignedStore(
+        value,
+        CreateBufferPtr(t, buf, MakeValue(op->index)),
+        data_layout_->getTypeAllocSize(value->getType()));
+    AddAliasInfo(inst, op->buffer_var.get(), op->index);
+  } else if (ramp && is_one(ramp->stride)) {
+    int alignment, native_bits;
+    GetAlignment(t, op->buffer_var.get(), ramp->base,
+                 &alignment, &native_bits);
+    int total_lanes = t.lanes();
+    int step = native_bits / t.bits();
+    // vector store.
+    for (int offset = 0; offset < total_lanes; offset += step) {
+      int lanes = std::min(step, total_lanes - offset);
+      Expr base = arith::ComputeExpr<Add>(
+          ramp->base, make_const(ramp->base.type(), offset));
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, MakeValue(base));
+      llvm::Type* vtype = llvm::VectorType::get(
+          LLVMType(t.element_of()), lanes)->getPointerTo();
+      llvm::StoreInst* inst = builder_->CreateAlignedStore(
+          CreateVecSlice(value, offset, lanes),
+          builder_->CreatePointerCast(ptr, vtype), alignment);
+      AddAliasInfo(inst, op->buffer_var.get(),
+                   Ramp::make(base, make_const(base.type(), 1), lanes));
+    }
   } else {
-    LOG(FATAL) << "not yet supported";
-    return nullptr;
+    Scalarize(op->index, [&](int i, llvm::Value* offset) {
+      llvm::Value* ptr = CreateBufferPtr(t.element_of(), buf, offset);
+      llvm::StoreInst* inst = builder_->CreateAlignedStore(
+          builder_->CreateExtractElement(value, ConstInt32(i)),
+          ptr, data_layout_->getTypeAllocSize(LLVMType(t)));
+      AddAliasInfo(inst, op->buffer_var.get(), Expr());
+    });
   }
 }
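Note: both vectorized paths split a wide access into native-width chunks: step = native_bits / t.bits() is the number of lanes per chunk (t.bits() being the element width), and the final chunk simply carries the leftover lanes. A sketch of the chunking arithmetic, using the default 512-bit width returned by NativeVectorBits (lane counts illustrative):

    #include <algorithm>
    #include <cassert>
    #include <utility>
    #include <vector>

    // (offset, lane count) pairs produced by the chunking loop in the
    // vectorized load/store paths.
    std::vector<std::pair<int, int>> Chunks(int total_lanes, int native_bits,
                                            int elem_bits) {
      int step = native_bits / elem_bits;
      std::vector<std::pair<int, int>> out;
      for (int offset = 0; offset < total_lanes; offset += step) {
        out.push_back({offset, std::min(step, total_lanes - offset)});
      }
      return out;
    }

    int main() {
      // float32 x 18 with a 512-bit native width: chunks of 16 and 2 lanes.
      auto c = Chunks(18, 512, 32);
      assert(c.size() == 2 && c[0] == std::make_pair(0, 16)
                           && c[1] == std::make_pair(16, 2));
      return 0;
    }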
 llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
   if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
     return CreateCallPacked(op);
@@ -904,24 +1175,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Call* op) {
   }
 }
-
-// stmts
-void CodeGenLLVM::VisitStmt_(const Store* op) {
-  llvm::Value* value = MakeValue(op->value);
-  Type t = op->value.type();
-  CHECK(!t.is_vector());
-  if (t.is_scalar()) {
-    llvm::StoreInst* inst = builder_->CreateAlignedStore(
-        value,
-        CreateBufferPtr(
-            t,
-            GetVarValue(op->buffer_var.get()),
-            MakeValue(op->index)),
-        data_layout_->getTypeAllocSize(value->getType()));
-    AddAliasInfo(inst, op->buffer_var.get(), op->index);
-  } else {
-    LOG(FATAL) << "not yet supported";
-  }
-}

 void CodeGenLLVM::VisitStmt_(const For* op) {
   CHECK(is_zero(op->min));
@@ -986,6 +1239,11 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
 }

 void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
+  if (op->type_key == ir::attr::storage_scope) {
+    const Variable* v = op->node.as<Variable>();
+    CHECK(v);
+    alloc_storage_scope_[v] = op->value.as<StringImm>()->value;
+  }
   this->VisitStmt(op->body);
 }
@@ -1014,7 +1272,9 @@ void CodeGenLLVM::VisitStmt_(const AssertStmt* op) {
 void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
   llvm::Value* v = MakeValue(op->value);
   CHECK(!var_map_.count(op->var.get()));
+  CHECK(!align_map_.count(op->var.get()));
   var_map_[op->var.get()] = v;
+  align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
   this->VisitStmt(op->body);
 }

 void CodeGenLLVM::VisitStmt_(const Block* op) {
@@ -14,6 +14,7 @@
 #include <vector>
 #include <string>
 #include "./llvm_common.h"
+#include "../../arithmetic/modular.h"

 namespace tvm {
 namespace codegen {
@@ -109,18 +110,29 @@ class CodeGenLLVM :
   virtual llvm::Value* CreateCallExtern(const Call* op);
   // create call into tvm packed function.
   virtual llvm::Value* CreateCallPacked(const Call* op);
+  // Scalarize e by iterating over the elements of e.
+  // f is a callback that takes the lane index i and that lane's value v.
+  virtual void Scalarize(const Expr& e,
+                         std::function<void(int i, llvm::Value* v)> f);

  protected:
   /*!
    * \param t The original type.
    * \return LLVM type of t
    */
   llvm::Type* LLVMType(const Type& t) const;
+  // initialize the function state.
+  void InitFuncState();
+  // Get the alignment of a load/store, given the index.
+  void GetAlignment(
+      Type t, const Variable* buf_var, const Expr& index,
+      int* p_alignment, int* p_native_bits);
   // do a scalarize call with f
   llvm::Value* CreateScalarizedCall(
       const Call* op, llvm::Function* f, const std::vector<llvm::Value*>& args);
   // apply optimization on the module.
   virtual void Optimize();
+  // Get the maximum storage alignment (in bits) of a buffer pointer, given its storage scope.
+  virtual int NativeVectorBits(const std::string& storage_scope) const;
   // The IRBuilder.
   using IRBuilder = llvm::IRBuilder<llvm::ConstantFolder, llvm::IRBuilderDefaultInserter>;
   // The current function
@@ -162,6 +174,8 @@ class CodeGenLLVM :
   llvm::Function* f_tvm_parallel_for_{nullptr};
   // The acting body
   llvm::BasicBlock* block_{nullptr};
+  /*! \brief the storage scope of allocation */
+  std::unordered_map<const Variable*, std::string> alloc_storage_scope_;

  private:
   // comparison op
@@ -178,6 +192,11 @@ class CodeGenLLVM :
   llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
   llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
   llvm::Value* GetPackedFuncHandle(const std::string& str);
+  // Vector manipulation helpers (slice, flip, concat, pad).
+  llvm::Value* CreateVecSlice(llvm::Value* vec, int begin, int extent);
+  llvm::Value* CreateVecFlip(llvm::Value* vec);
+  llvm::Value* CreateVecConcat(std::vector<llvm::Value*> vecs);
+  llvm::Value* CreateVecPad(llvm::Value* vec, int target_lanes);
   // Create parallel for.
   void CreateParallelFor(const For* op);
   // Create serial for
@@ -197,6 +216,8 @@ class CodeGenLLVM :
   std::unordered_map<const Variable*, llvm::Value*> var_map_;
   // global strings
   std::unordered_map<std::string, llvm::Constant*> str_map_;
+  // The alignment information
+  std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
   // The local module_context
   llvm::GlobalVariable* gv_mod_ctx_{nullptr};
   // global to packed function handle
@@ -355,7 +355,9 @@ class Vectorizer : public IRMutator {
   const Ramp* a_ramp = a.as<Ramp>();
   if (a.type().lanes() == 1 && b_ramp) {
     return Ramp::make(
-        arith::ComputeExpr<T>(a, b_ramp->base), b_ramp->stride, b_ramp->lanes);
+        arith::ComputeExpr<T>(a, b_ramp->base),
+        arith::ComputeExpr<T>(make_zero(b_ramp->stride.type()), b_ramp->stride),
+        b_ramp->lanes);
   }
   if (b.type().lanes() == 1 && a_ramp) {
     return Ramp::make(
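Note: this hunk fixes the stride of the result when a scalar a is combined with Ramp(base, stride, n) under a non-commutative operator T: lane i holds T(a, base + i * stride), so the result's base is T(a, base) and its stride must be T applied to the increment, i.e. ComputeExpr<T>(0, stride). The old code reused b_ramp->stride directly, which is wrong for Sub. A one-line arithmetic check of the Sub case:

    #include <cassert>

    int main() {
      // Sub case: a - (base + i*stride) == (a - base) + i * (0 - stride)
      int a = 10, base = 2, stride = 3;
      for (int i = 0; i < 4; ++i) {
        assert(a - (base + i * stride) == (a - base) + i * (0 - stride));
      }
      return 0;
    }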
@@ -2,13 +2,15 @@ import tvm
 import numpy as np

 def test_llvm_add_pipeline():
-    n = tvm.Var('n')
+    nn = 1024
+    n = tvm.convert(nn)
     A = tvm.placeholder((n,), name='A')
     B = tvm.placeholder((n,), name='B')
     C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
     s = tvm.Schedule(C.op)
-    s[C].parallel(C.op.axis[0])
+    xo, xi = s[C].split(C.op.axis[0], factor=4)
+    s[C].parallel(xo)
+    s[C].vectorize(xi)
     def check_llvm():
         if not tvm.codegen.enabled("llvm"):
             return
@@ -16,16 +18,71 @@ def test_llvm_add_pipeline():
         f = tvm.build(s, [A, B, C], "llvm")
         ctx = tvm.cpu(0)
         # launch the kernel.
-        n = 1027 * 1024
+        n = nn
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
         b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
         c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
-        for i in range(1000):
-            f(a, b, c)
+        f(a, b, c)
         np.testing.assert_allclose(
             c.asnumpy(), a.asnumpy() + b.asnumpy())
     check_llvm()

+
+def test_llvm_flip_pipeline():
+    def check_llvm(nn, base):
+        if not tvm.codegen.enabled("llvm"):
+            return
+        n = tvm.convert(nn)
+        A = tvm.placeholder((n + base), name='A')
+        C = tvm.compute((n,), lambda i: A(nn + base - i - 1), name='C')
+        s = tvm.Schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=4)
+        s[C].parallel(xo)
+        s[C].vectorize(xi)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = nn
+        a = tvm.nd.array(np.random.uniform(size=(n + base)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        f(a, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy()[::-1][:n])
+    check_llvm(4, 0)
+    check_llvm(128, 8)
+    check_llvm(3, 0)
+    check_llvm(128, 1)
+
+
+def test_llvm_madd_pipeline():
+    def check_llvm(nn, base, stride):
+        if not tvm.codegen.enabled("llvm"):
+            return
+        n = tvm.convert(nn)
+        A = tvm.placeholder((n + base, stride), name='A')
+        C = tvm.compute((n, stride), lambda i, j: A(base + i, j) + 1, name='C')
+        s = tvm.Schedule(C.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=4)
+        s[C].parallel(xo)
+        s[C].vectorize(xi)
+        # build and invoke the kernel.
+        f = tvm.build(s, [A, C], "llvm")
+        ctx = tvm.cpu(0)
+        # launch the kernel.
+        n = nn
+        a = tvm.nd.array(np.random.uniform(size=(n + base, stride)).astype(A.dtype), ctx)
+        c = tvm.nd.array(np.zeros((n, stride), dtype=C.dtype), ctx)
+        f(a, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a.asnumpy()[base:] + 1)
+    check_llvm(64, 0, 2)
+    check_llvm(4, 0, 1)
+    check_llvm(4, 0, 3)
+

 if __name__ == "__main__":
     test_llvm_add_pipeline()
+    test_llvm_flip_pipeline()
+    test_llvm_madd_pipeline()