Unverified Commit f08d5d78 by Tianqi Chen Committed by GitHub

[TIR] Refactor MakePackedAPI to target dependent stage. (#5326)

Previously, MakePackedAPI ran in the target-independent stage, but it
nevertheless requires the device_type information that is only bound
at a later, target-dependent stage.

The previous placement was due to a limitation of LoweredFunc, which
could not carry buffer_map information (so buffers had to be lowered
right away). This is no longer the case after the unified IR refactor.

This PR migrates MakePackedAPI to a target-dependent stage
and removes the now-unnecessary BindDeviceType pass.
parent 4720cf85
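For context, here is a minimal sketch of the flow this change enables. It assumes
the post-refactor Python API from this PR; the schedule and the name "add_one" are
illustrative, not from the patch. After this PR, tvm.lower returns an IRModule whose
PrimFunc still carries its buffer_map, and MakePackedAPI runs later, once the
"target" attribute is attached:

    import tvm
    from tvm import te

    n = te.var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # lower() no longer calls MakePackedAPI; the returned module
    # keeps the high-level buffer_map on its PrimFunc.
    mod = tvm.lower(s, [A, B], name="add_one")

    # MakePackedAPI now runs in the target-dependent stage: it reads
    # the device_type from the function's "target" attribute.
    f = mod["add_one"].with_attr("target", tvm.target.create("llvm"))
    mod = tvm.tir.transform.MakePackedAPI(0)(tvm.IRModule({"add_one": f}))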
@@ -352,12 +352,19 @@ class Sequential : public Pass {
  *
  * \return The created module pass.
  */
-Pass CreateModulePass(
+TVM_DLL Pass CreateModulePass(
     const runtime::TypedPackedFunc<IRModule(IRModule, PassContext)>& pass_func,
     int opt_level,
     const std::string& name,
     const Array<runtime::String>& required);
+
+/*!
+ * \brief A special trace pass that prints the header and IR to LOG(INFO).
+ * \return The pass.
+ */
+TVM_DLL Pass PrintIR(std::string header);
 }  // namespace transform
 }  // namespace tvm
...
@@ -193,6 +193,15 @@ class TVM_DLL DeviceAPI {
   * \return The corresponding device API.
   */
  static DeviceAPI* Get(TVMContext ctx, bool allow_missing = false);
+  /*!
+   * \brief Whether a certain device type requires a set-device-context
+   *        call before launching the kernel function.
+   * \param device_type The device type.
+   */
+  static bool NeedSetDeviceContext(int device_type) {
+    return device_type != kDLCPU && device_type != kDLMicroDev;
+  }
 };

 /*! \brief The device type bigger than this is RPC device */
...
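The new NeedSetDeviceContext helper is what MakePackedAPI consults below before
emitting a __tvm_set_device call. A Python analogue of the predicate, purely for
illustration (the numeric codes follow dlpack/TVM device-type conventions; the
value assumed for kDLMicroDev is not part of the patch):

    kDLCPU = 1        # dlpack device-type code for CPU
    kDLMicroDev = 13  # TVM extension code for microTVM (assumed value)

    def need_set_device_context(device_type: int) -> bool:
        # Only devices other than the CPU and the micro device need an
        # explicit set-device call before a kernel launch.
        return device_type not in (kDLCPU, kDLMicroDev)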
@@ -112,15 +112,6 @@ TVM_DLL Pass RemapThreadAxis(Map<runtime::String, IterVar> axis_map);
  */
 TVM_DLL Pass LowerCustomDatatypes();

-/*!
- * \brief Bind the device type of the function to be
- *        the device_type specified in the target attribute.
- *
- * \return The pass.
- */
-TVM_DLL Pass BindDeviceType();
-
 /*!
  * \brief Split the function into a host function and device functions.
  *
...
@@ -200,7 +200,7 @@ def lower(sch,
     if cfg.restricted_func:
         f = f.with_attr("tir.noalias", True)
     mod = tvm.IRModule({name: f})
-    return tvm.tir.transform.MakePackedAPI()(mod)
+    return mod


 def _build_for_device(input_mod, target, target_host):
@@ -243,13 +243,13 @@ def _build_for_device(input_mod, target, target_host):
                  tvm.tir.transform.ThreadSync("warp"),
                  tvm.tir.transform.InferFragment(),
                  tvm.tir.transform.LowerThreadAllreduce(),
-                 tvm.tir.transform.BindDeviceType(),
+                 tvm.tir.transform.MakePackedAPI(),
                  tvm.tir.transform.SplitHostDevice()]
-    mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)
+    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
     # device optimizations
-    opt_device = tvm.ir.transform.Sequential(
+    opt_device = tvm.transform.Sequential(
         [tvm.tir.transform.Filter(
             lambda f: "calling_conv" in f.attrs and
             f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
@@ -259,7 +259,7 @@ def _build_for_device(input_mod, target, target_host):
     mod_dev = opt_device(mod_mixed)
     # host optimizations
-    opt_host = tvm.ir.transform.Sequential(
+    opt_host = tvm.transform.Sequential(
         [tvm.tir.transform.Filter(
             lambda f: "calling_conv" not in f.attrs or
             f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
...
@@ -22,13 +22,13 @@ import functools
 import tvm._ffi
-from tvm._ffi.runtime_ctypes import TVMContext
-from tvm.runtime import Object, ndarray as _nd
+import tvm.runtime
+from tvm.runtime import ndarray as _nd

 from . import _ffi_transform_api


 @tvm._ffi.register_object("transform.PassInfo")
-class PassInfo(Object):
+class PassInfo(tvm.runtime.Object):
     """The class contains the meta data required by a pass. It is the
     container of information needed by running an optimization or analysis.
     This class can be extended by adding new members when more meta data is
@@ -52,7 +52,7 @@ class PassInfo(Object):

 @tvm._ffi.register_object("transform.PassContext")
-class PassContext(Object):
+class PassContext(tvm.runtime.Object):
     """The basis where a Relay optimization/analysis runs on.
     Each pass context contains a number of auxiliary information that is used
     to help an optimization pass. Such information includes the error reporter
@@ -79,7 +79,7 @@ class PassContext(Object):
                  trace=None):
         if isinstance(fallback_device, str):
             fallback_device = _nd.context(fallback_device).device_type
-        elif isinstance(fallback_device, TVMContext):
+        elif isinstance(fallback_device, tvm.runtime.TVMContext):
             fallback_device = fallback_device.device_type
         if not isinstance(fallback_device, int):
             raise TypeError("fallback_device is expected to be the type of " +
@@ -113,7 +113,7 @@ class PassContext(Object):

 @tvm._ffi.register_object("transform.Pass")
-class Pass(Object):
+class Pass(tvm.runtime.Object):
     """The base class of all passes. All methods here are just simple wrappers
     that are implemented in the backend. They are defined for users to
     conveniently interact with the base class.
@@ -327,3 +327,18 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
     if pass_func:
         return create_module_pass(pass_func)
     return create_module_pass
+
+
+def PrintIR(header):
+    """A special trace pass that prints the header and IR.
+
+    Parameters
+    ----------
+    header : str
+        The header to be displayed along with the dump.
+
+    Returns
+    -------
+    The pass
+    """
+    return _ffi_transform_api.PrintIR(header)
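A quick usage sketch for the new PrintIR pass (the schedule is illustrative;
PrintIR itself logs the header and module via LOG(INFO) on the C++ side and
returns the module unchanged):

    import tvm
    from tvm import te

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B], name="double")

    # Sandwich a pass with traces; each PrintIR logs its header plus the
    # current IRModule, leaving the module itself untouched.
    seq = tvm.transform.Sequential([
        tvm.ir.transform.PrintIR("before NarrowDataType"),
        tvm.tir.transform.NarrowDataType(32),
        tvm.ir.transform.PrintIR("after NarrowDataType"),
    ])
    mod = seq(mod)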
@@ -195,13 +195,14 @@ def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
     mod : IRModule
         The created IRModule.
     """
+    assert num_unpacked_args == 0
     f = tvm.tir.PrimFunc(args, stmt).with_attr(
         "global_symbol", tvm.runtime.String(name))
     f = f.with_attr("tir.is_entry_func", True)
     if noalias:
         f = f.with_attr("tir.noalias", True)
     mod = tvm.IRModule({name: f})
-    return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
+    return mod


 tvm._ffi._init_api("testing", __name__)
@@ -32,7 +32,7 @@ def Apply(ftransform):
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     # pylint: disable=unused-argument
@@ -51,7 +51,7 @@ def Filter(fcond):
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     # pylint: disable=unused-argument
@@ -67,7 +67,7 @@ def LowerCustomDatatypes():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.LowerCustomDatatypes()
@@ -84,30 +84,18 @@ def MakePackedAPI(num_unpacked_params=0):
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.MakePackedAPI(num_unpacked_params)


-def BindDeviceType():
-    """Bind the device type of the function to be
-    the device_type specified in the target attribute.
-
-    Returns
-    -------
-    fpass : tvm.ir.transform.Pass
-        The result pass
-    """
-    return _ffi_api.BindDeviceType()
-
-
 def SplitHostDevice():
     """Split the function into a host function and device functions.

     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.SplitHostDevice()
@@ -118,7 +106,7 @@ def SkipAssert():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.SkipAssert()
@@ -134,7 +122,7 @@ def ThreadSync(storage_scope):
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.ThreadSync(storage_scope)
@@ -145,7 +133,7 @@ def LowerThreadAllreduce():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.LowerThreadAllreduce()
@@ -156,7 +144,7 @@ def InferFragment():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.InferFragment()
@@ -167,7 +155,7 @@ def LowerWarpMemory():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.LowerWarpMemory()
@@ -178,7 +166,7 @@ def LowerTVMBuiltin():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.LowerTVMBuiltin()
@@ -189,7 +177,7 @@ def LowerIntrin():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.LowerIntrin()
@@ -200,7 +188,7 @@ def LowerDeviceStorageAccessInfo():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass

     Note
@@ -215,7 +203,7 @@ def CombineContextCall():
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass
     """
     return _ffi_api.CombineContextCall()
@@ -231,7 +219,7 @@ def NarrowDataType(target_bits):
     Returns
     -------
-    fpass : tvm.ir.transform.Pass
+    fpass : tvm.transform.Pass
         The result pass

     Note
...
@@ -216,8 +216,7 @@ IRModule lower(te::Schedule sch,
   if (config->restricted_func) {
     f = WithAttr(std::move(f), "tir.noalias", Integer(1));
   }
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  return tir::transform::MakePackedAPI(0)(mod);
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
 }
@@ -237,7 +236,7 @@ split_dev_host_funcs(IRModule mod_mixed,
   mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
   mixed_pass_list.push_back(tir::transform::InferFragment());
   mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
-  mixed_pass_list.push_back(tir::transform::BindDeviceType());
+  mixed_pass_list.push_back(tir::transform::MakePackedAPI(0));
   mixed_pass_list.push_back(tir::transform::SplitHostDevice());
   auto opt_mixed = transform::Sequential(mixed_pass_list);
   mod_mixed = opt_mixed(std::move(mod_mixed));
...
@@ -473,5 +473,18 @@ TVM_REGISTER_GLOBAL("transform.EnterPassContext")
 TVM_REGISTER_GLOBAL("transform.ExitPassContext")
 .set_body_typed(PassContext::Internal::ExitScope);

+Pass PrintIR(std::string header) {
+  auto pass_func = [header](IRModule mod, const PassContext& ctx) {
+    LOG(INFO) << "PrintIR(" << header << "):\n"
+              << mod;
+    return mod;
+  };
+  return CreateModulePass(pass_func, 0, "PrintIR", {});
+}
+
+TVM_REGISTER_GLOBAL("transform.PrintIR")
+.set_body_typed(PrintIR);
+
 }  // namespace transform
 }  // namespace tvm
@@ -58,6 +58,8 @@ StackVM::StructFieldKind MapFieldKind(int64_t kind) {
 }

 StackVM CodeGenStackVM::Compile(const PrimFunc& f) {
+  CHECK_EQ(f->buffer_map.size(), 0U)
+      << "Cannot codegen function with buffer_map, please lower them first";
   for (size_t i = 0; i < f->params.size(); ++i) {
     Var v = f->params[i];
     int vid = AllocVarID(v.get());
...
@@ -114,9 +114,11 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
   /// Check if the value of a Variable comes from function argument.
   bool IsFromFunctionArgs(const VarNode *var) const {
     const VarNode *V = var;
-    while (true) {
-      CHECK(V) << "Invalid Variable\n";
+    for (auto kv : func_->buffer_map) {
+      if (V == kv.second->data.get()) return true;
+    }
+    while (true) {
       // Variable is from function args. Return true.
       if (V == func_->params[0].get()) return true;
...
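This verifier change mirrors the new pipeline order: since MakePackedAPI now runs
later, verification can see a PrimFunc whose buffer_map is still intact, so the
data Vars of mapped buffers must also count as argument-derived. A rough Python
analogue of the check (a hypothetical helper over tvm.tir.PrimFunc, not the C++
verifier itself):

    def is_from_function_args(var, func):
        """Rough analogue of MemoryAccessVerifier::IsFromFunctionArgs."""
        # Data pointers of mapped buffers count as argument-derived now
        # that buffer_map is no longer flattened away before verification.
        for _, buf in func.buffer_map.items():
            if var.same_as(buf.data):
                return True
        return any(var.same_as(p) for p in func.params)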
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file bind_device_type.cc
- * \brief Bind the device type according to the target field.
- */
-#include <tvm/ir/transform.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/op.h>
-#include <tvm/tir/transform.h>
-#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/target/target.h>
-#include <tvm/runtime/registry.h>
-
-namespace tvm {
-namespace tir {
-
-class DeviceTypeBinder : public StmtExprMutator {
- public:
-  explicit DeviceTypeBinder(int device_type)
-      : device_type_(device_type) {}
-
-  Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::device_context_type) {
-      if (const VarNode* var = op->value.as<VarNode>()) {
-        var_ = var;
-        PrimExpr value = make_const(op->value.dtype(), device_type_);
-        Stmt body = StmtExprMutator::VisitStmt_(op);
-        var_ = nullptr;
-        std::ostringstream os;
-        os << "device_type need to be " << device_type_;
-        return AssertStmtNode::make(op->value == value,
-                                    tvm::tir::StringImmNode::make(os.str()),
-                                    body);
-      }
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-
-  Stmt VisitStmt_(const IfThenElseNode* op) final {
-    // eager simplify if guard.
-    Stmt res = StmtExprMutator::VisitStmt_(op);
-    op = res.as<IfThenElseNode>();
-    if (is_zero(op->condition)) {
-      if (op->else_case.defined()) return op->else_case;
-      return EvaluateNode::make(0);
-    }
-    if (is_one(op->condition)) {
-      return op->then_case;
-    }
-    return res;
-  }
-
-  PrimExpr VisitExpr_(const NENode* op) final {
-    // eager check NE for device check
-    PrimExpr res = StmtExprMutator::VisitExpr_(op);
-    op = res.as<NENode>();
-    if (tir::ExprDeepEqual()(op->a, op->b)) {
-      return make_const(op->dtype, false);
-    }
-    return res;
-  }
-
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    if (op == var_) {
-      return make_const(op->dtype, device_type_);
-    } else {
-      return GetRef<PrimExpr>(op);
-    }
-  }
-
- public:
-  const VarNode* var_{nullptr};
-  int device_type_;
-};
-
-namespace transform {
-
-Pass BindDeviceType() {
-  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    auto* n = f.CopyOnWrite();
-    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
-    CHECK(target.defined())
-        << "BindDeviceType: Require the target attribute";
-    n->body = DeviceTypeBinder(target.value()->device_type)(std::move(n->body));
-    return f;
-  };
-  return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {});
-}
-
-TVM_REGISTER_GLOBAL("tir.transform.BindDeviceType")
-.set_body_typed(BindDeviceType);
-
-}  // namespace transform
-}  // namespace tir
-}  // namespace tvm
@@ -26,6 +26,7 @@
 #include <tvm/tir/transform.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/buffer.h>
+#include <tvm/target/target.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/container.h>
@@ -50,6 +51,12 @@ PrimFunc MakePackedAPI(PrimFunc&& func,
   auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   CHECK(global_symbol)
       << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
+
+  auto target = func->GetAttr<Target>(tvm::attr::kTarget);
+  CHECK(target.defined())
+      << "MakePackedAPI: Require the target attribute";
+  int target_device_type = target.value()->device_type;
+
   std::string name_hint = global_symbol.value();

   auto* func_ptr = func.CopyOnWrite();
@@ -68,7 +75,8 @@ PrimFunc MakePackedAPI(PrimFunc&& func,
   // The arguments of the function.
   Array<Var> args;
   // The device context
-  Var device_type("dev_type"), device_id("dev_id");
+  Var device_id("dev_id");
+  Integer device_type(target_device_type);
   // seq_init gives sequence of initialization
   // seq_check gives sequence of later checks after init
   std::vector<Stmt> seq_init, seq_check;
@@ -195,17 +203,18 @@ PrimFunc MakePackedAPI(PrimFunc&& func,
   // Set device context
   if (vmap.count(device_id.get())) {
     PrimExpr node = StringImmNode::make("default");
-    CHECK(vmap.count(device_type.get()));
     seq_check.push_back(AttrStmtNode::make(
         node, attr::device_context_id, device_id, nop));
     seq_check.push_back(AttrStmtNode::make(
         node, attr::device_context_type, device_type, nop));
-    Stmt set_device = IfThenElseNode::make(
-        device_type != kDLCPU, EvaluateNode::make(CallNode::make(
-            DataType::Int(32), intrinsic::tvm_call_packed,
-            {StringImmNode::make(runtime::symbol::tvm_set_device),
-             device_type, device_id}, CallNode::Intrinsic)));
-    body = SeqStmt({set_device, body});
+
+    if (runtime::DeviceAPI::NeedSetDeviceContext(target_device_type)) {
+      Stmt set_device = EvaluateNode::make(CallNode::make(
+          DataType::Int(32), intrinsic::tvm_call_packed,
+          {StringImmNode::make(runtime::symbol::tvm_set_device),
+           device_type, device_id}, CallNode::Intrinsic));
+      body = SeqStmt({set_device, body});
+    }
   }
   func_ptr->body = MergeNest(
       {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
...
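A compact sketch of the new contract, mirroring the updated test further below
(the trivial function body is illustrative): the "target" attribute must be
attached before MakePackedAPI runs, otherwise the pass fails its CHECK; the
function name "main" comes from IRModule.from_expr.

    import tvm

    a = tvm.tir.Var("a", "handle")
    f = tvm.tir.PrimFunc([a], tvm.tir.Evaluate(tvm.tir.const(0, "int32")))
    f = f.with_attr("global_symbol", "main")
    # Without this attribute the pass now fails with
    # "MakePackedAPI: Require the target attribute".
    f = f.with_attr("target", tvm.target.create("llvm"))

    mod = tvm.IRModule.from_expr(f)
    mod = tvm.tir.transform.MakePackedAPI(0)(mod)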
@@ -39,10 +39,8 @@ def test_dltensor_compatible():
         A[i + 1] = A[i] + 1
     stmt = ib.get()
     mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
-    mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
-    f = tvm.target.codegen.build_module(mod, "stackvm")
+    f = tvm.build(mod, target="stackvm")
     a = tvm.nd.array(np.zeros(10, dtype=dtype))
     aview = MyTensorView(a)
     f(aview)
...
@@ -111,7 +111,7 @@ def test_llvm_lookup_intrin():
     x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.v8i8", tvm.tir.const(1, 'uint32'), A[z])
     ib.emit(x)
     body = ib.get()
-    func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
+    func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 0, True)
     fcode = tvm.build(func, None, "llvm")
...
@@ -44,7 +44,7 @@ def lower(sch, args):
     f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
         "global_symbol", tvm.runtime.String("test"))
     mod = tvm.IRModule({"test": f})
-    return tvm.tir.transform.MakePackedAPI()(mod)
+    return mod


 # All computations are bound.
...
@@ -40,9 +40,10 @@ def test_double_buffer():
     stmt = tvm.tir.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
+    mod = tvm.IRModule({
+        "db" : tvm.tir.PrimFunc([A.asobject(), C.asobject()], stmt)
+    })
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
     count = [0]

     def count_sync(op):
         if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
...
@@ -92,7 +92,7 @@ def test_flatten_double_buffer():
     stmt = tvm.tir.ir_pass.Simplify(stmt)
     assert isinstance(stmt.body.body, tvm.tir.Allocate)
     assert stmt.body.body.extents[0].value == 2
-    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
+    mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 0, True)
     f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
     count = [0]
...
@@ -36,7 +36,10 @@ def test_for():
     ib.emit(tvm.tir.call_extern
             ("int32", "fadd", device_context(0), A))
     body = ib.get()
-    mod = tvm.testing.MakeAPILegacy(body, "func", [dev_type, n], 2, True)
+    mod = tvm.IRModule({
+        "func" : tvm.tir.PrimFunc([dev_type, n], body)
+    })
     mod = tvm.tir.transform.CombineContextCall()(mod)
     assert mod["func"].body.value.dtype == "handle"
...
@@ -35,8 +35,10 @@ def test_makeapi():
     stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
     num_unpacked_args = 2
-    f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr(
-        "tir.noalias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
+    f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt)
+    f = f.with_attr("global_symbol", "myadd")
+    f = f.with_attr("target", tvm.target.create("llvm"))
     mod = tvm.IRModule.from_expr(f)
     f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
     assert(len(f.params) == 7)
...
@@ -60,7 +60,7 @@ def test_cow_pass():
     del func
     # copy on write
     mod_hash = mod.__hash__()
-    mod = tvm.ir.transform.Sequential(
+    mod = tvm.transform.Sequential(
         [pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move())
     assert mod_hash == mod.__hash__()
     assert func_hash == mod["main"].__hash__()
...