Commit 0a19b16a by Tianqi Chen Committed by GitHub

[CODEGEN/PASS] add restricted, alignment option (#221)

* [CODEGEN/PASS] add restricted, alignment option

* fix lint

* Fix the alloca
parent 00506a62
......@@ -233,6 +233,9 @@ Stmt LoopPartition(Stmt stmt);
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
* \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap.
* It is recommended to set to true for optimized code if such invariant holds.
*
* \return a LoweredFunc with the specified signiture.
*
* \note
......@@ -254,7 +257,8 @@ Stmt LoopPartition(Stmt stmt);
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_unpacked_args);
int num_unpacked_args,
bool is_restricted);
/*!
* \brief Find undefined vars in the statment.
......
......@@ -86,6 +86,11 @@ class LoweredFuncNode : public FunctionBaseNode {
LoweredFuncType func_type{kMixedFunc};
/*! \brief Whether this function is packed function */
bool is_packed_func{true};
/*!
* \brief Whether function ensures that argument pointers do not alias.
* This corresponds to restrict keyword in C.
*/
bool is_restricted{true};
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
......@@ -104,6 +109,7 @@ class LoweredFuncNode : public FunctionBaseNode {
v->Visit("handle_data_type", &handle_data_type);
v->Visit("func_type", &func_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("is_restricted", &is_restricted);
v->Visit("body", &body);
}
......
......@@ -24,6 +24,9 @@ enum DeviceAttrKind : int {
/*! \brief Number of bytes each allocation must align to */
constexpr int kAllocAlignment = 64;
/*! \brief Number of bytes each allocation must align to in temporary allocation */
constexpr int kTempAllocaAlignment = 64;
/*!
* \brief TVM Runtime Device API, abstracts the device
* specific interface for memory management.
......
......@@ -27,7 +27,9 @@ class BuildConfig(object):
'auto_unroll_min_depth': 1,
'unroll_explicit': True,
'detect_global_barrier': False,
'offset_factor': 0
'offset_factor': 0,
'data_alignment': 0,
'restricted_func': True
}
def __init__(self, **kwargs):
self._old_scope = None
......@@ -77,10 +79,20 @@ def build_config(**kwargs):
detect_global_barrier: bool, default=True
Whether detect global barrier.
data_alignment: int, optional
The alignment of data pointer in bytes.
If 0 is passed, the alignment will be set to TVM's internal default.
offset_factor: int, default=0
The factor used in default buffer declaration.
If specified as 0, offset field is not used.
restricted_func: bool, default=True
Whether build restricted function.
That is each buffer argument to the function are guaranteed
not to overlap. This enables more optimization.
Corresponds to restricted keyword in C99
Returns
-------
config: BuildConfig
......@@ -110,12 +122,15 @@ def get_binds(args, binds=None):
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()
offset_factor = BuildConfig.current.offset_factor
cfg = BuildConfig.current
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
buf = api.decl_buffer(x.shape, dtype=x.dtype, name=x.name,
offset_factor=offset_factor)
buf = api.decl_buffer(x.shape,
dtype=x.dtype,
name=x.name,
data_alignment=cfg.data_alignment,
offset_factor=cfg.offset_factor)
assert x not in binds
binds[x] = buf
arg_list.append(buf)
......@@ -181,7 +196,7 @@ def lower(sch,
stmt = ir_pass.Simplify(stmt)
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0)
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
def build(sch,
......
......@@ -68,6 +68,12 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
*ret = PassName(args[0], args[1], args[2], args[3]); \
}) \
#define REGISTER_PASS5(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2], args[3], args[4]); \
}) \
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(CanonicalSimplify);
......@@ -76,7 +82,7 @@ REGISTER_PASS2(StorageFlatten);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS4(UnrollLoop);
REGISTER_PASS2(StorageSync);
REGISTER_PASS4(MakeAPI);
REGISTER_PASS5(MakeAPI);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(InjectVirtualThread);
......
......@@ -46,6 +46,9 @@ void CodeGenC::AddFunction(LoweredFunc f) {
if (handle_data_type_.count(v.get())) {
PrintType(handle_data_type_.at(v.get()), stream);
stream << "*";
if (f->is_restricted && restrict_keyword_.length() != 0) {
stream << ' ' << restrict_keyword_;
}
} else {
PrintType(v.type(), stream);
}
......
......@@ -167,6 +167,8 @@ class CodeGenC :
// override
void PrintSSAAssign(
const std::string& target, const std::string& src, Type t) final;
/*! \brief restrict keyword */
std::string restrict_keyword_{""};
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
/*! \brief the data type of allocated buffers */
......
......@@ -14,6 +14,10 @@
namespace tvm {
namespace codegen {
CodeGenCUDA::CodeGenCUDA() {
restrict_keyword_ = "__restrict__";
}
void CodeGenCUDA::Init(bool output_ssa) {
CodeGenC::Init(output_ssa);
vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state);
......
......@@ -16,6 +16,7 @@ namespace codegen {
class CodeGenCUDA final : public CodeGenC {
public:
CodeGenCUDA();
void Init(bool output_ssa);
void AddFunction(LoweredFunc f);
// override behavior
......
......@@ -12,6 +12,10 @@
namespace tvm {
namespace codegen {
CodeGenOpenCL::CodeGenOpenCL() {
restrict_keyword_ = "restrict";
}
void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
CodeGenC::InitFuncState(f);
for (Var arg : f->args) {
......
......@@ -16,6 +16,7 @@ namespace codegen {
class CodeGenOpenCL final : public CodeGenC {
public:
CodeGenOpenCL();
void AddFunction(LoweredFunc f);
// override print thread tag.
void InitFuncState(LoweredFunc f) final;
......
......@@ -5,6 +5,7 @@
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/device_api.h>
#include <tvm/ir_pass.h>
#include "./codegen_llvm.h"
#include "../../pass/ir_util.h"
......@@ -65,6 +66,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
md_very_likely_branch_ =
md_builder_->createBranchWeights(1 << 30, 0);
md_tbaa_root_ = md_builder_->createTBAARoot("tvmtbaa");
md_tbaa_alias_set_ = md_builder_->createTBAAScalarTypeNode(
"alias_set", md_tbaa_root_);
}
ctx_ = ctx;
// initialize modules
......@@ -131,10 +134,12 @@ void CodeGenLLVM::InitFuncState() {
var_map_.clear();
align_map_.clear();
alloc_storage_info_.clear();
alias_var_set_.clear();
}
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
this->InitFuncState();
is_restricted_ = f->is_restricted;
CHECK(!module_->getFunction(f->name))
<< "Function " << f->name << "already exists in module";
std::vector<llvm::Type*> arg_type;
......@@ -143,6 +148,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
if (t.is_handle() && f->handle_data_type.count(arg)) {
arg_type.push_back(
LLVMType(f->handle_data_type[arg].type())->getPointerTo());
if (!is_restricted_) {
alias_var_set_.insert(arg.get());
}
} else {
arg_type.push_back(LLVMType(t));
}
......@@ -265,6 +273,14 @@ llvm::BasicBlock* CodeGenLLVM::CheckCallSuccess(llvm::Value* retcode) {
void CodeGenLLVM::AddAliasInfo(
llvm::Instruction* inst, const Variable* buffer, Expr index, Type t) {
if (alias_var_set_.count(buffer) != 0) {
// Mark all possibly aliased pointer as same type.
llvm::MDNode* meta = md_tbaa_alias_set_;
inst->setMetadata(
"tbaa",
md_builder_->createTBAAStructTagNode(meta, meta, 0));
return;
}
int base = 0, width = 0;
// create meta-data for alias analysis
// Use a group of binary tree ranges.
......@@ -1324,10 +1340,10 @@ void CodeGenLLVM::VisitStmt_(const Allocate* op) {
LLVMType(op->type), ConstInt32(constant_size));
buf = alloca;
StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
// Align stack to be multiple of 4 if it is
// Align stack to be TempAllocaAlignment.
// TODO(tqchen) have pass to detect vector access and pre-set alignment
if (constant_size % 4 == 0 && info.alignment == 0) {
info.alignment = op->type.bytes() * 4;
info.alignment = GetTempAllocaAlignment(op->type, constant_size);
}
if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
alloca->setAlignment(info.alignment);
......@@ -1408,6 +1424,11 @@ void CodeGenLLVM::VisitStmt_(const LetStmt* op) {
llvm::Value* v = MakeValue(op->value);
CHECK(!var_map_.count(op->var.get()));
CHECK(!align_map_.count(op->var.get()));
if (op->var.type().is_handle()) {
if (!is_restricted_) {
alias_var_set_.insert(op->var.get());
}
}
var_map_[op->var.get()] = v;
align_map_[op->var.get()] = arith::EvalModular(op->value, align_map_);
this->VisitStmt(op->body);
......
......@@ -174,6 +174,7 @@ class CodeGenLLVM :
// branch
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
llvm::MDNode* md_tbaa_alias_set_{nullptr};
// TVM related data types
llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
......@@ -234,6 +235,10 @@ class CodeGenLLVM :
std::unordered_map<std::string, llvm::Constant*> str_map_;
// The alignment information
std::unordered_map<const Variable*, arith::ModularEntry> align_map_;
// Whether current function is restricted
bool is_restricted_{true};
// set of var that are not restricted(can alias)
std::unordered_set<const Variable*> alias_var_set_;
// The local module_context
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
// global to packed function handle
......
......@@ -100,7 +100,7 @@ Buffer Buffer::MakeStrideView() const {
Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
const BufferNode* n = operator->();
Expr elem_offset = ElemOffset(n, begins);
Expr elem_offset = ir::Simplify(ElemOffset(n, begins));
Array<Expr> strides = n->strides;
if (strides.size() == 0) {
bool can_relax = true;
......@@ -146,6 +146,9 @@ Buffer BufferNode::make(Var data,
n->shape = std::move(shape);
n->strides = std::move(strides);
n->name = std::move(name);
if (scope.length() == 0) {
scope = "global";
}
n->scope = std::move(scope);
if (!elem_offset.defined()) {
elem_offset = make_const(n->shape[0].type(), 0);
......@@ -156,7 +159,7 @@ Buffer BufferNode::make(Var data,
if (offset_factor == 0) {
offset_factor = 1;
}
n->elem_offset = elem_offset;
n->elem_offset = std::move(elem_offset);
n->data_alignment = data_alignment;
n->offset_factor = offset_factor;
return Buffer(n);
......
......@@ -80,41 +80,64 @@ void ArgBinder::BindBuffer(const Buffer& arg,
CHECK_EQ(arg->scope, value->scope)
<< "Argument " << arg_name
<< " Buffer bind scope mismatch";
CHECK_EQ(arg->dtype, value->dtype)
<< "Argument " << arg_name
<< " Buffer bind data type mismatch";
if (value->data_alignment % arg->data_alignment != 0) {
LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
<< " required_alignment=" << arg->data_alignment
<< ", provided_alignment=" << value->data_alignment;
}
// bind pointer and offset.
if (is_zero(arg->elem_offset) && !is_zero(value->elem_offset)) {
// If target buffer only accepts pointer, try to fold offset into pointer.
Expr addr = AddressOffset(value->data, value->dtype, value->elem_offset);
if (Bind_(arg->data, addr, arg_name + ".data", true)) {
int offset_factor = arg->data_alignment * 8 / (arg->dtype.bits() * arg->dtype.lanes());
if (offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
} else {
this->Bind(arg->data, value->data, arg_name + ".data");
if (arg->shape.size() > value->shape.size()) {
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
if (arg->shape.size() < value->shape.size()) {
CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
size_t diff = arg->shape.size() - value->shape.size();
size_t diff = value->shape.size() - arg->shape.size();
for (size_t i = 0; i < diff; ++i) {
CHECK(is_one(arg->shape[i]))
CHECK(is_one(value->shape[i]))
<< "Argument " << arg_name << " shape mismatch"
<< arg->shape << " vs " << value->shape;
}
for (size_t i = 0; i < value->shape.size(); ++i) {
for (size_t i = 0; i < arg->shape.size(); ++i) {
std::ostringstream os;
os << arg_name << ".shape[" << i << "]";
this->Bind(arg->shape[i + diff], value->shape[i], os.str());
this->Bind(arg->shape[i], value->shape[i + diff], os.str());
}
if (arg->strides.size() != 0) {
if (value->strides.size() != 0) {
CHECK_EQ(arg->strides.size(), arg->shape.size());
CHECK_EQ(value->strides.size(), value->shape.size());
for (size_t i = 0; i < value->strides.size(); ++i) {
for (size_t i = 0; i < arg->strides.size(); ++i) {
std::ostringstream os;
os << arg_name << ".strides[" << i << "]";
this->Bind(arg->strides[i + diff], value->strides[i], os.str());
this->Bind(arg->strides[i], value->strides[i + diff], os.str());
}
}
} else {
this->BindArray(arg->shape, value->shape, arg_name + ".shape");
this->BindArray(arg->strides, value->strides, arg_name + ".strides");
}
if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
if (arg->offset_factor > 1) {
Expr offset = value->elem_offset;
Expr factor = make_const(offset.type(), arg->offset_factor);
Expr zero = make_zero(offset.type());
BinderAddAssert(offset % factor == zero, arg_name + ".elem_offset", &asserts_);
}
}
}
inline Expr TVMArrayGet(Type t, Var arr, intrinsic::TVMStructFieldKind kind) {
......
......@@ -7,6 +7,7 @@
#define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h>
#include <tvm/runtime/device_api.h>
#include <vector>
namespace tvm {
......@@ -93,6 +94,20 @@ inline Expr AddressOffset(Var handle, Type dtype, int offset) {
}
/*!
* \brief Address of handle + offset
* \param handle the array handle.
* \param dtype The data type.
* \param offset the offset index.
*/
inline Expr AddressOffset(Var handle, Type dtype, Expr offset) {
return Call::make(
Handle(), intrinsic::tvm_address_of,
{Load::make(dtype, handle, offset * make_const(offset.type(), dtype.lanes()),
const_true(dtype.lanes()))},
Call::PureIntrinsic);
}
/*!
* \brief Set value into struct.
* \param handle the struct handle.
* \param index the offset index.
......@@ -125,6 +140,23 @@ inline Type APIType(Type t) {
CHECK(t.is_float());
return Float(64);
}
/*!
* \brief Rule to get allocation alignment requirement for a given const array.
* \param type The type of allocation.
* \param const_size The constant size of the array.
* \return the alignment
*/
inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
int align = runtime::kTempAllocaAlignment;
if (const_size > 0) {
const_size = const_size * type.bits() * type.lanes() / 8;
while (align > const_size) {
align = align / 2;
}
}
return align;
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
......@@ -25,7 +25,8 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_unpacked_args) {
int num_unpacked_args,
bool is_restricted) {
const Stmt nop = Evaluate::make(0);
int num_args = static_cast<int>(api_args.size());
CHECK_LE(num_unpacked_args, num_args);
......@@ -132,6 +133,7 @@ LoweredFunc MakeAPI(Stmt body,
n->args = args;
n->handle_data_type = binder.def_handle_dtype();
n->is_packed_func = num_unpacked_args == 0;
n->is_restricted = is_restricted;
// Set device context
if (vmap.count(device_id.get())) {
......
......@@ -7,6 +7,7 @@
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "./ir_util.h"
#include "./arg_binder.h"
......@@ -88,11 +89,6 @@ class StorageFlattener : public IRMutator {
for (auto r : e.bounds) {
shape.push_back(r->extent);
}
e.buffer = decl_buffer(shape, op->type, key.GetName());
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
buf_map_[key].released = true;
// deduce current storage scope.
auto it = storage_scope_.find(op->func.get());
CHECK(it != storage_scope_.end())
......@@ -107,12 +103,26 @@ class StorageFlattener : public IRMutator {
} else {
skey = StorageScope::make(strkey);
}
// use small alignment for small arrays
int32_t const_size = Allocate::constant_allocation_size(shape, key.GetName());
int align = GetTempAllocaAlignment(op->type, const_size);
e.buffer = BufferNode::make(
Var(key.GetName(), Handle()),
op->type, shape,
Array<Expr>(), Expr(),
key.GetName(), skey.to_string(),
align, 0);
buf_map_[key] = e;
Stmt body = this->Mutate(op->body);
buf_map_[key].released = true;
Stmt ret = Allocate::make(
e.buffer->data, e.buffer->dtype, e.buffer->shape,
make_const(Bool(e.buffer->dtype.lanes()), true), body);
ret = AttrStmt::make(
e.buffer->data, attr::storage_scope,
StringImm::make(skey.to_string()), ret);
StringImm::make(e.buffer->scope), ret);
return ret;
}
}
......
......@@ -93,6 +93,14 @@ class StorageAccessPatternFinder final : public IRVisitor {
AccessEntry(buf, op->index, kRead, GetScope(buf)));
}
}
void Visit_(const Call* op) final {
if (op->is_intrinsic(intrinsic::tvm_address_of)) {
const Load* l = op->args[0].as<Load>();
this->Visit(l->index);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Variable* buf) final {
// Directly reference to the variable count as a read.
auto it = alloc_scope_level_.find(buf);
......
......@@ -16,7 +16,7 @@ def lower(s, args, name="mydot"):
stmt = tvm.ir_pass.StorageFlatten(stmt, binds)
stmt = tvm.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0)
fapi = tvm.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
return fapi
......
......@@ -25,7 +25,7 @@ def test_add_pipeline():
stmt = tvm.ir_pass.LoopPartition(stmt)
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0, True)
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(fapi)]
fsplits[0] = tvm.ir_pass.LowerPackedCall(fsplits[0])
......
......@@ -6,7 +6,8 @@ def test_llvm_add_pipeline():
n = tvm.convert(nn)
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
T = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='T')
C = tvm.compute(A.shape, lambda *i: T(*i), name='C')
s = tvm.create_schedule(C.op)
xo, xi = s[C].split(C.op.axis[0], factor=4)
s[C].parallel(xo)
......@@ -87,6 +88,7 @@ def test_llvm_madd_pipeline():
c.asnumpy(), a.asnumpy()[base:] + 1)
check_llvm(64, 0, 2)
check_llvm(4, 0, 1)
with tvm.build_config(restricted_func=False):
check_llvm(4, 0, 3)
def test_llvm_temp_space():
......
......@@ -19,7 +19,7 @@ def test_stack_vm_basic():
n = tvm.var('n')
Ab = tvm.decl_buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0)
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
run_jit(fapi, lambda f: f(a))
......@@ -41,7 +41,7 @@ def test_stack_vm_loop():
ib.emit(tvm.call_packed("tvm_stack_vm_print", i))
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
......@@ -64,7 +64,7 @@ def test_stack_vm_cond():
A[i + 1] = A[i] + 2
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0)
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
......
......@@ -37,7 +37,7 @@ def test_dso_module_load():
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
m = tvm.codegen.build_module(fapi, "llvm")
for name in names:
......
......@@ -19,7 +19,7 @@ def test_makeapi():
num_unpacked_args = 2
f = tvm.ir_pass.MakeAPI(
stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args)
stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
assert(len(f.args) == 5)
output_ssa = False
......
......@@ -20,7 +20,7 @@ def test_storage_sync():
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0)
f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
flist = tvm.ir_pass.SplitHostDevice(f)
f = flist[1]
f = tvm.ir_pass.StorageSync(f, "shared")
......
......@@ -24,7 +24,7 @@ def test_dltensor_compatible():
with ib.for_range(0, n - 1, "i") as i:
A[i + 1] = A[i] + 1
stmt = ib.get()
fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0)
fapi = tvm.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
fapi = tvm.ir_pass.LowerPackedCall(fapi)
f = tvm.codegen.build_module(fapi, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment