Unverified Commit e63e08fe by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Migrate all low-level passes to the Pass Manager. (#5233)

* [REFACTOR][TIR] Migrate all low-level passes to the Pass Manager.

This PR migrates tvm.lower to return an IRModule of PrimFuncs
instead of a list of LoweredFuncs.

* Remove LoweredFunc.
parent fd9ce583
......@@ -46,7 +46,6 @@ def __lldb_init_module(debugger, _):
"tvm::IterVarAttr",
"tvm::IterVarRelation",
"tvm::Layout",
"tir::LoweredFunc",
"tvm::Map",
"tvm::Map",
"tvm::MemoryInfo",
......
......@@ -145,15 +145,6 @@ After lowering is done, ``build()`` function generates target machine code from
Code generation is done by the ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in the ``src/target/codegen`` subdirectory. The ``build_module()`` Python function will reach the ``Build()`` function below in ``src/target/codegen/codegen.cc``:
::
runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target) {
std::string build_f_name = "codegen.build_" + target;
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
runtime::Module m = (*bf)(funcs, target);
return m;
}
The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry and invokes the function found. For example, the ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:
......
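As an aside, this registry-based dispatch can be observed directly from Python. A minimal sketch (the ``codegen.build_<target>`` key follows the snippet above and has shifted across refactors, so treat it as illustrative)::

    import tvm
    # Code generators are ordinary PackedFuncs in the global registry.
    f = tvm._ffi.get_global_func("codegen.build_cuda", allow_missing=True)
    print("CUDA codegen registered:", f is not None)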
......@@ -32,8 +32,8 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/support/with.h>
#include <tvm/ir/module.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/lowered_func.h>
#include <string>
#include <vector>
......@@ -43,15 +43,15 @@
namespace tvm {
/*!
* \brief Build a LoweredFunc given a schedule, args and binds
* \brief Build an IRModule given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param config The build configuration.
* \return The lowered function.
* \return The result module.
*/
TVM_DLL Array<tir::LoweredFunc> lower(
TVM_DLL IRModule lower(
te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
......@@ -59,44 +59,43 @@ TVM_DLL Array<tir::LoweredFunc> lower(
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \brief Build a device and host module for a specific target from an IRModule.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The built module.
*/
TVM_DLL runtime::Module build(const Array<tir::LoweredFunc>& funcs,
TVM_DLL runtime::Module build(const IRModule& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
 * of target to IRModule pairs. This function is used
* for heterogeneous build.
* \param input The map contains target to a list of lowered functions pairs.
 * \param input The map of target to IRModule pairs.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<Target, Array<tir::LoweredFunc>>& input,
TVM_DLL runtime::Module build(const Map<Target, IRModule>& input,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
 * of target to IRModule pairs. This function is used
* for heterogeneous build.
* \param input The map contains target string to a list of lowered functions
* pairs.
 * \param input The map of target string to IRModule pairs.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<std::string, Array<tir::LoweredFunc>>& input,
TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input,
const Target& target_host,
const BuildConfig& config);
} // namespace tvm
......
......@@ -297,6 +297,15 @@ class IRModule : public ObjectRef {
CHECK(ptr != nullptr);
return static_cast<IRModuleNode*>(ptr);
}
/*!
* \brief Construct an empty module.
*
* \returns The constructed module
*/
static IRModule Empty() {
return IRModule(Map<GlobalVar, BaseFunc>());
}
/*!
* \brief Construct a module from a standalone expression.
*
......
......@@ -27,7 +27,6 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>
#include <string>
......@@ -42,17 +41,6 @@ using runtime::TVMArgs;
using runtime::TVMRetValue;
/*!
 * \brief Temporary backward-compatible function to convert a list
 * of LoweredFuncs to an IRModule of PrimFuncs.
 * \param funcs The input lowered functions.
* \return The IRModule.
*
* \note This function is only used for code refactor and will be
* removed once the refactor completes.
*/
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);
/*!
 * \brief Build a module from an IRModule of lowered functions.
 * \param mod The Module to be built
 * \param target The target to build for.
......
......@@ -24,9 +24,12 @@
#ifndef TVM_TIR_ANALYSIS_H_
#define TVM_TIR_ANALYSIS_H_
#include <tvm/ir/module.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
namespace tvm {
namespace tir {
......@@ -59,6 +62,18 @@ struct ExprDeepEqual {
*/
Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
/*!
 * \brief Verify that memory accesses are legal for a specific target device type.
 *
 * When the target is CUDA, if not all of the workload is bound to
 * threads, CPU code is generated that tries to access GPU memory,
 * which is illegal. This pass verifies that this case does not occur.
*
* \param mod The module to be verified.
* \return Success of memory verification.
*/
void VerifyMemory(const IRModule& mod);
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
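A sketch of the failure mode this verifier guards against (assuming a CUDA-enabled build): lowering a schedule with no thread binding and building for ``cuda`` trips the "Did you forget to bind?" error, while binding the axes makes the accesses legal::

    import tvm
    from tvm import te

    A = te.placeholder((1024,), name="A")
    B = te.compute((1024,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    # tvm.build(s, [A, B], target="cuda")  # fails: host code touches GPU memory
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))
    rt_mod = tvm.build(s, [A, B], target="cuda")  # passes memory verification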
......@@ -31,7 +31,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>
#include <tvm/tir/lowered_func.h>
#include <unordered_map>
#include <unordered_set>
......@@ -367,60 +366,6 @@ Stmt HoistIfThenElse(Stmt stmt);
Stmt NarrowDataType(Stmt stmt, int target_bits);
/*!
 * \brief Make a user-callable API LoweredFunc.
*
 * The main task of this function is to create code to:
 * - Map the values in the api_args to the Vars required by the body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
 * \param is_restricted Whether the caller can guarantee that the buffer arguments do not overlap.
 *  It is recommended to set this to true for optimized code if such an invariant holds.
*
 * \return A LoweredFunc with the specified signature.
*
* \note
 * The function signature has two cases:
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n,
* TVMValue* out_ret_val, int* out_ret_tcode)
*
* where n == len(api_args), k == num_packed_args
*
 * There is no thread_axis in the generated function.
*/
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<ObjectRef> api_args,
int num_unpacked_args,
bool is_restricted);
/*!
* \brief Remap the thread axis
*
 * This can be used to get an equivalent program that uses
* threadIdx.y in place of threadIdx.x by passing
* {"threadIdx.x": thread_axis("threadIdx.y")}
*
*
* \param f The device function to be lowered.
 * \param axis_map The map from StringImm -> IterVar
* \return Transformed function.
*/
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
......@@ -433,31 +378,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
PrimFunc PointerValueTypeRewrite(PrimFunc f);
/*!
* \brief Lower custom datatypes.
*
* See tvm::datatypes::Registry for more information on adding custom datatypes.
*
* \param f The device function to be lowered.
* \param target The target device.
* \return Transformed function.
*/
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
/*!
 * \brief Verify that memory accesses are legal for a specific target device type.
 *
 * When the target is CUDA, if not all of the workload is bound to
 * threads, CPU code is generated that tries to access GPU memory,
 * which is illegal. This pass verifies that this case does not occur.
*
* \param func The function to be verified.
* \param device_type The target device type.
* \return Success of memory verification.
*/
bool VerifyMemory(LoweredFunc func, int device_type);
/*!
* \brief Verify the correctness of a GPU code
* It will check the whether the amount of memory usage or the number of threads
* in a block exceeds the limit
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/lowered_func.h
* \brief Information about a lowered TVM function.
 * This data structure is the final step toward codegen.
*/
#ifndef TVM_TIR_LOWERED_FUNC_H_
#define TVM_TIR_LOWERED_FUNC_H_
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <string>
namespace tvm {
namespace tir {
// Internal node container of lowered function.
class LoweredFuncNode;
/*!
 * \brief LoweredFunc represents a function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const LoweredFuncNode* operator->() const;
/*! \brief specify container node */
using ContainerType = LoweredFuncNode;
};
/*! \brief specific type of lowered function */
enum LoweredFuncType : int {
/*! \brief Function that can mix device and host calls */
kMixedFunc = 0,
/*! \brief Only contains host code */
kHostFunc = 1,
/*! \brief Only contains device code */
kDeviceFunc = 2
};
/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public tir::FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The arguments of the function
 * This function can only take POD types (int, float) and void* as arguments.
*/
Array<Var> args;
/*!
* \brief The IterVar axis of threads
 * Each axis needs the host function to specify a size.
* \note Calling convention into LoweredFunc
*
 * Assume we have a LoweredFunc f; a call into f looks like
* Call(f, arg1, arg2, ..., arg_n,
* size_axis_1, size_axis_2, ... size_axis_m)
*
* Here n = len(args), m = len(thread_axis)
*
 * The CodeGen should take this and translate this call
 * to the corresponding API-specific kernel launches or function calls.
*/
Array<IterVar> thread_axis;
/*!
* \brief The hint data type of Var handles defined in LetStmt
 * Can be used as a hint when generating the type signature.
* The creation rule is given by
* handle_data_type[var_handle] = make_const(the_type, 0);
*
 * \note Expr is used instead of Type because Type cannot be held by a Map;
 * a constant Expr of the given type is used.
*/
Map<Var, PrimExpr> handle_data_type;
/*! \brief The type of the function */
LoweredFuncType func_type{kMixedFunc};
/*! \brief Whether this function is a packed function */
bool is_packed_func{true};
/*!
* \brief Whether function ensures that argument pointers do not alias.
* This corresponds to restrict keyword in C.
*/
bool is_restricted{true};
/*! \brief The body statement of the function */
Stmt body;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
// There is no return value, but we return 1
// to enable calls into this function.
int num_outputs() const final {
return 1;
}
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("func_type", &func_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("is_restricted", &is_restricted);
v->Visit("body", &body);
}
static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object);
};
// Implementations of inline functions
inline const LoweredFuncNode* LoweredFunc::operator->() const {
return static_cast<const LoweredFuncNode*>(get());
}
} // namespace tir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash {
};
}
#endif // TVM_TIR_LOWERED_FUNC_H_
......@@ -59,6 +59,61 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);
/*!
* \brief Transform the high-level PrimFunc to a low-level version
* that can be used as an API function.
*
*
 * The main task of this function is to create code to:
 * - Map the values in the api_args to the Vars required by the body.
* - Insert assertions to check type/value of the passed arguments.
*
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
*
* \note
 * The function signature has two cases:
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
* if num_packed_args is not zero:
* f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
* api_arg_k, api_arg_k+1, ... api_arg_n,
* TVMValue* out_ret_val, int* out_ret_tcode)
*
 * where n == len(api_args), k == num_packed_args
 *
 * For example, with len(api_args) == 3 and num_unpacked_args == 1,
 * num_packed_args == 2: api_arg_0 and api_arg_1 travel through the
 * packed arguments, while api_arg_2 is passed directly.
*
* \return The pass.
*/
TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
/*!
* \brief Remap the thread axis
*
 * This can be used to get an equivalent program that uses
* threadIdx.y in place of threadIdx.x by passing
* {"threadIdx.x": thread_axis("threadIdx.y")}
*
*
* \return The pass.
*/
TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
/*!
* \brief Lower custom datatypes.
*
* See tvm::datatypes::Registry for more information on adding custom datatypes.
*
* \return The pass.
*/
TVM_DLL Pass LowerCustomDatatypes();
/*!
 * \brief Bind the device type of the function to be
* the device_type specified in the target attribute.
*
......
......@@ -17,9 +17,6 @@
# pylint: disable=invalid-name
"""The build utils in python.
This module provides the functions to transform a schedule into
a LoweredFunc and a compiled Module.
"""
import warnings
......@@ -30,7 +27,6 @@ from tvm.ir import container
from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.tir.stmt import LoweredFunc
from tvm.te import tensor
from tvm.te import schedule
from tvm import target as _target
......@@ -136,8 +132,8 @@ def lower(sch,
Returns
-------
f : LoweredFunc or Stmt
The result function, if with_api_wrapper=False
m : IRModule or Stmt
The result IRModule if simple_mode=False;
otherwise the Stmt before MakePackedAPI is returned.
"""
cfg = BuildConfig.current()
......@@ -199,16 +195,21 @@ def lower(sch,
if simple_mode:
return stmt
return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
if cfg.restricted_func:
f = f.with_attr("tir.no_alias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI()(mod)
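A sketch of the two return modes documented above (``tvm.lower`` now produces an IRModule unless ``simple_mode`` is set)::

    import tvm
    from tvm import te

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B], name="double")      # IRModule of PrimFuncs
    stmt = tvm.lower(s, [A, B], simple_mode=True)  # bare Stmt, before MakePackedAPI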
def _build_for_device(flist, target, target_host):
def _build_for_device(input_mod, target, target_host):
"""Build the lowered functions for a device with the given compilation
target.
Parameters
----------
flist : list of LoweredFunc
input_mod : IRModule
The input IRModule to be built.
target : str or :any:`tvm.target.Target`
......@@ -219,8 +220,8 @@ def _build_for_device(flist, target, target_host):
Returns
-------
fhost : list of LoweredFunc
A list of lowered functions for the host.
fhost : IRModule
The host IRModule.
mdev : tvm.module
A module that contains device code.
......@@ -229,14 +230,13 @@ def _build_for_device(flist, target, target_host):
target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
mod_mixed = input_mod
mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
tvm.tir.analysis.verify_memory(mod_mixed)
mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
opt_mixed = []
if len(mod_mixed.functions) == 1:
opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
if BuildConfig.current().detect_global_barrier:
opt_mixed += [tvm.tir.transform.ThreadSync("global")]
opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
......@@ -292,7 +292,7 @@ def build(inputs,
Parameters
----------
inputs : tvm.te.Schedule, LoweredFunc, or dict of target to LoweredFunc list
inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
The schedule to be built
args : list of Buffer or Tensor or Var, optional
......@@ -326,7 +326,7 @@ def build(inputs,
________
There are two typical example uses of this function depending on the type
of the argument `inputs`:
1. it is a list of lowered functions:
1. it is an IRModule:
.. code-block:: python
......@@ -335,10 +335,10 @@ def build(inputs,
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
f = tvm.lower(s, [A, B, C], name="test_add")
m = tvm.build(f, target="llvm")
m = tvm.lower(s, [A, B, C], name="test_add")
rt_mod = tvm.build(m, target="llvm")
2. it is a dict of compilation target to list of lowered functions:
2. it is a dict of compilation target to IRModule:
.. code-block:: python
......@@ -349,9 +349,9 @@ def build(inputs,
s1 = tvm.te.create_schedule(C.op)
with tvm.target.cuda() as cuda_tgt:
s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
f1 = tvm.lower(s1, [A, B, C], name="test_add1")
f2 = tvm.lower(s2, [A, B, C], name="test_add2")
m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
m1 = tvm.lower(s1, [A, B, C], name="test_add1")
m2 = tvm.lower(s2, [A, B, C], name="test_add2")
rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")
Note
----
......@@ -360,45 +360,36 @@ def build(inputs,
if isinstance(inputs, schedule.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
flist = lower(inputs, args,
input_mod = lower(inputs, args,
name=name,
binds=binds)
if isinstance(flist, LoweredFunc):
flist = [flist]
elif isinstance(inputs, LoweredFunc):
if args:
raise ValueError("args must be done when build from LoweredFunc.")
flist = [inputs]
elif isinstance(inputs, (list, tuple, container.Array)):
flist = inputs
merged_mod = tvm.IRModule({})
for x in inputs:
merged_mod.update(x)
input_mod = merged_mod
elif isinstance(inputs, tvm.IRModule):
input_mod = inputs
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError("inputs must be Schedule, LoweredFunc, list of "
"LoweredFunc, or dict of target to list of "
"LoweredFunc.")
raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")
if not isinstance(inputs, (dict, container.Map)):
target = _target.Target.current() if target is None else target
target = target if target else "llvm"
target_flist = {target: flist}
target_input_mod = {target: input_mod}
else:
target_flist = inputs
target_input_mod = inputs
for tar, flist in target_flist.items():
for tar, mod in target_input_mod.items():
if not isinstance(tar, (str, _target.Target)):
raise ValueError("The key of inputs must be str or "
"_target.Target when inputs is dict.")
fname_set = set()
for x in flist:
if not isinstance(x, LoweredFunc):
raise ValueError("inputs must be Schedule, LoweredFunc, list "
"of LoweredFunc, or dict of str to list of "
"LoweredFunc.")
if x.name in fname_set:
raise ValueError("Duplicate function name %s" % x.name)
fname_set.add(x.name)
if not isinstance(mod, tvm.IRModule):
raise ValueError("inputs must be Schedule, IRModule,"
"or dict of str to IRModule.")
if not target_host:
for tar, _ in target_flist.items():
for tar, _ in target_input_mod.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
......@@ -410,8 +401,8 @@ def build(inputs,
mod_host_all = tvm.IRModule({})
device_modules = []
for tar, flist in target_flist.items():
mod_host, mdev = _build_for_device(flist, tar, target_host)
for tar, input_mod in target_input_mod.items():
mod_host, mdev = _build_for_device(input_mod, tar, target_host)
mod_host_all.update(mod_host)
device_modules.append(mdev)
......
......@@ -17,7 +17,6 @@
"""The interface of expr function exposed from C++."""
import tvm._ffi
import tvm.driver
from tvm.ir import container as _container
@tvm._ffi.register_func("relay.backend.lower")
......@@ -40,7 +39,7 @@ def lower(sch, inputs, func_name, source_func):
Returns
-------
lowered_funcs : List[tvm.LoweredFunc]
mod : tvm.IRModule
The result of lowering.
"""
# pylint: disable=broad-except, import-outside-toplevel
......@@ -56,20 +55,17 @@ def lower(sch, inputs, func_name, source_func):
msg += "-----------------------------\n"
msg += source_func.astext()
raise RuntimeError(msg)
return f if isinstance(
f, (_container.Array, tuple, list)) else [f]
return f
@tvm._ffi.register_func("relay.backend.build")
def build(funcs, target, target_host=None):
def build(mod, target, target_host=None):
"""Backend build function.
Parameters
----------
funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
A list of lowered functions or dictionary mapping from targets to
lowered functions.
mod : tvm.IRModule or Dict[str, tvm.IRModule]
Input module
target : tvm.Target
The target to run the code on.
......@@ -84,7 +80,7 @@ def build(funcs, target, target_host=None):
"""
if target_host == "":
target_host = None
return tvm.driver.build(funcs, target=target, target_host=target_host)
return tvm.driver.build(mod, target=target, target_host=target_host)
@tvm._ffi.register_func("relay._tensor_value_repr")
......
......@@ -48,7 +48,7 @@ class GraphRuntimeCodegen(object):
self._get_graph_json = self._mod["get_graph_json"]
self._list_params_name = self._mod["list_params_name"]
self._get_param_by_name = self._mod["get_param_by_name"]
self._get_lowered_funcs = self._mod["get_lowered_funcs"]
self._get_irmodule = self._mod["get_irmodule"]
self._setup(mod, target)
def _setup(self, mod, target):
......@@ -74,14 +74,14 @@ class GraphRuntimeCodegen(object):
-------
graph_json : str
The graph json that can be consumed by runtime.
lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
mod : IRModule or Dict[str, IRModule]
The lowered functions.
params : Dict[str, tvm.nd.NDArray]
Additional constant parameters.
"""
self._codegen(func)
graph_json = self._get_graph_json()
lowered_func = self._get_lowered_funcs()
lowered_func = self._get_irmodule()
param_names = self._list_params_name()
params = {}
for name in param_names:
......
......@@ -28,3 +28,4 @@ from .object_generic import convert_to_object, convert, const
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
from .module import load_module, enabled, system_lib
from .container import String
......@@ -20,9 +20,7 @@ import tvm._ffi
import tvm.ir
from tvm.runtime import Object
from tvm.ir import container
from tvm.tir import Stmt
from tvm.tir.stmt import LoweredFunc
from . import _ffi_api
......@@ -48,17 +46,13 @@ class DumpIR(object):
def dump(*args, **kwargs):
"""dump function"""
retv = func(*args, **kwargs)
if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
if not isinstance(retv, (Stmt,)):
return retv
fname = func.func_name if hasattr(func, 'func_name') else func.__name__
pname = str(self._pass_id) + "_" + fname + "_ir.cc"
with open(pname, "a") as f:
out = retv.body if isinstance(retv, LoweredFunc) else retv
out = retv
f.write(str(out))
if isinstance(retv, container.Array):
for x in retv:
out = x.body if isinstance(x, LoweredFunc) else x
f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
self._pass_id += 1
return retv
return dump
......
......@@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
""" TVM testing utilities """
import logging
import numpy as np
import tvm
import tvm._ffi
......@@ -165,4 +168,40 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
x_name, grad.shape, dist, max_diff, avg_diff)
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to build a Module from statement.
Used for migrating existing test cases only.
Parameters
----------
stmt: Stmt
The input statement.
name: str
The name of the function.
args: list of Buffer or Vars
The function arguments
num_unpacked_args: int
Number of unpacked arguments.
noalias: bool
Whether to allow noalias.
Returns
-------
mod : IRModule
The created IRModule.
"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
mod = tvm.IRModule({name: f})
return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
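A sketch of using the adapter in a migrated test (``tvm.tir.Evaluate(0)`` as a trivial body is purely illustrative)::

    import tvm

    stmt = tvm.tir.Evaluate(0)
    mod = tvm.testing.MakeAPILegacy(stmt, "main", [], 0, True)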
tvm._ffi._init_api("testing", __name__)
......@@ -29,7 +29,7 @@ from .expr import IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .function import PrimFunc
......
......@@ -55,3 +55,14 @@ def expr_deep_equal(lhs, rhs):
tvm.ir.structural_equal
"""
return _ffi_api.expr_deep_equal(lhs, rhs)
def verify_memory(mod):
"""Verify if module contains illegal host side direct memory access.
Parameters
----------
mod: tvm.IRModule
The module to be verified.
"""
_ffi_api.verify_memory(mod)
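A sketch of running the verifier from Python (the C++ implementation reads each function's ``target`` attribute, so one is attached first, mirroring ``_build_for_device`` above)::

    import tvm
    from tvm import te

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B], name="add_one")
    mod = tvm.tir.transform.Apply(
        lambda f: f.with_attr("target", tvm.target.create("llvm")))(mod)
    tvm.tir.analysis.verify_memory(mod)  # aborts if host code touches device memory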
......@@ -18,6 +18,7 @@
import tvm._ffi
import tvm.runtime
from tvm.runtime import Object
from tvm.ir import BaseFunc
from .buffer import Buffer
from .expr import Var
......@@ -54,6 +55,7 @@ class PrimFunc(BaseFunc):
param_list = []
buffer_map = {} if buffer_map is None else buffer_map
for x in params:
x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
if isinstance(x, Buffer):
var = Var(x.name, dtype="handle")
param_list.append(var)
......
......@@ -385,14 +385,6 @@ class Prefetch(Stmt):
_ffi_api.Prefetch, func, value_index, dtype, bounds)
@tvm._ffi.register_object
class LoweredFunc(Object):
"""Represent a LoweredFunc in TVM."""
MixedFunc = 0
HostFunc = 1
DeviceFunc = 2
def stmt_seq(*args):
"""Make sequence of statements
......
......@@ -60,6 +60,36 @@ def Filter(fcond):
return _fpass.prim_func_pass(_transform, opt_level=0)
def LowerCustomDatatypes():
"""Lower custom datatypes.
See tvm::datatypes::Registry for more information on adding custom datatypes.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.LowerCustomDatatypes()
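A sketch of applying the pass (the C++ implementation further below checks that each PrimFunc carries a ``target`` attribute, so one must be attached first)::

    import tvm
    # assume `mod` is an IRModule of PrimFuncs, e.g. the result of tvm.lower(...)
    mod = tvm.tir.transform.Apply(
        lambda f: f.with_attr("target", tvm.target.create("llvm")))(mod)
    mod = tvm.tir.transform.LowerCustomDatatypes()(mod)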
def MakePackedAPI(num_unpacked_params=0):
"""Transform the PrimFuncs in the module to a packed func API.
Parameters
----------
num_unpacked_params : int
Number of parameters that should be passed directly via normal arguments
following the PackedFunc input signature.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.MakePackedAPI(num_unpacked_params)
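A sketch mirroring what ``tvm.lower`` now does internally (see the driver change above): wrap a PrimFunc, attach ``global_symbol``, and run the pass::

    import tvm

    Ab = tvm.tir.decl_buffer((8,), "float32", name="A")
    f = tvm.tir.PrimFunc([Ab], tvm.tir.Evaluate(0)).with_attr(
        "global_symbol", tvm.runtime.String("main"))
    mod = tvm.IRModule({"main": f})
    mod = tvm.tir.transform.MakePackedAPI(0)(mod)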
def BindDeviceType():
"""Bind the device type of the function to be
the device_type specified in the target attribute.
......
......@@ -27,7 +27,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/te/schedule.h>
#include <map>
#include <string>
......
......@@ -26,8 +26,10 @@
#include <tvm/te/operation.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/target/codegen.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <algorithm>
......@@ -39,7 +41,6 @@ namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
using tir::LoweredFunc;
bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
......@@ -166,17 +167,6 @@ tir::Stmt BuildStmt(te::Schedule sch,
return stmt;
}
Array<LoweredFunc> lower(te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config) {
Array<ObjectRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
}
transform::Pass BindTarget(Target target) {
auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
return WithAttr(std::move(f), tvm::attr::kTarget, target);
......@@ -198,18 +188,46 @@ transform::Pass FilterBy(FCond fcond) {
}
IRModule lower(te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
const BuildConfig& config) {
Array<ObjectRef> out_arg_list;
auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
Array<tir::Var> params;
Map<tir::Var, tir::Buffer> buffer_map;
for (auto var : out_arg_list) {
if (auto* n = var.as<tir::VarNode>()) {
params.push_back(GetRef<tir::Var>(n));
} else {
tir::Buffer buffer = Downcast<tir::Buffer>(var);
tir::Var bptr(buffer->name, DataType::Handle());
params.push_back(bptr);
buffer_map.Set(bptr, buffer);
}
}
auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map);
f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
if (config->restricted_func) {
f = WithAttr(std::move(f), "tir.no_alias", Integer(1));
}
auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
return tir::transform::MakePackedAPI(0)(mod);
}
std::pair<IRModule, IRModule>
split_dev_host_funcs(const Array<LoweredFunc>& funcs,
split_dev_host_funcs(IRModule mod_mixed,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
for (const auto& x : funcs) {
CHECK(tir::VerifyMemory(x, target->device_type))
<< "Direct host side access to device memory is detected in "
<< x->func_name() << ". Did you forget to bind?";
}
IRModule mod_mixed = codegen::ToIRModule(funcs);
mod_mixed = BindTarget(target)(std::move(mod_mixed));
tir::VerifyMemory(mod_mixed);
Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
if (config->detect_global_barrier) {
......@@ -274,10 +292,9 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs,
// Build for heterogeneous execution.
runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
runtime::Module build(const Map<Target, IRModule>& inputs,
const Target& target_host,
const BuildConfig& config) {
Array<LoweredFunc> fhost_all;
std::vector<runtime::Module> device_modules;
Target target_host_val = target_host;
......@@ -319,10 +336,10 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
}
// Build for heterogeneous execution when target is a string.
runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
runtime::Module build(const Map<std::string, IRModule>& inputs,
const Target& target_host,
const BuildConfig& config) {
Map<Target, Array<LoweredFunc>> updated_input;
Map<Target, IRModule> updated_input;
for (const auto& it : inputs) {
auto target = Target::Create(it.first);
if (target->device_name == "vta") {
......@@ -334,11 +351,11 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
}
// Build for homogeneous execution.
runtime::Module build(const Array<LoweredFunc>& funcs,
runtime::Module build(const IRModule& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
Map<Target, Array<LoweredFunc>> inputs = {{target, funcs}};
Map<Target, IRModule> inputs = {{target, funcs}};
return build(inputs, target_host, config);
}
......
......@@ -38,7 +38,6 @@ namespace tvm {
namespace relay {
namespace backend {
using tir::LoweredFunc;
using TargetsMap = Map<tvm::Integer, tvm::Target>;
using namespace tvm::relay::transform;
......@@ -78,16 +77,16 @@ struct GraphCodegen {
}
Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
}
Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
Map<std::string, IRModule> GetIRModule() {
return CallFunc<Map<std::string, IRModule>>("get_irmodule", nullptr);
}
std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
std::unordered_map<std::string, tvm::runtime::NDArray> ret;
auto names = CallFunc<Array<tvm::PrimExpr> >("list_params_name", nullptr);
auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
for (auto expr : names) {
auto key = expr.as<tir::StringImmNode>()->value;
ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
......@@ -152,9 +151,9 @@ class RelayBuildModule : public runtime::ModuleNode {
this->SetParam(kv.first, kv.second->data);
}
});
} else if (name == "get_lowered_funcs") {
} else if (name == "get_irmodule") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetLoweredFunc();
*rv = this->graph_codegen_->GetIRModule();
});
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
......@@ -452,7 +451,7 @@ class RelayBuildModule : public runtime::ModuleNode {
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
auto lowered_funcs = graph_codegen_->GetLoweredFunc();
auto lowered_funcs = graph_codegen_->GetIRModule();
// There may be no lowered_funcs at all, e.g. due to optimization.
if (lowered_funcs.size() == 0) {
......
......@@ -27,7 +27,6 @@
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
......@@ -82,7 +81,8 @@ struct CachedFuncNode : public Object {
/*! \brief The schedule to the function */
te::Schedule schedule;
/*! \brief The lowered functions to support the function. */
tvm::Array<tir::LoweredFunc> funcs;
IRModule funcs = IRModule::Empty();
/*! \brief Parameter usage states in the shape function. */
tvm::Array<Integer> shape_func_param_states;
......
......@@ -55,7 +55,7 @@ using TargetsMap = std::unordered_map<int, Target>;
/*! \brief Lowered outputs */
struct LoweredOutput {
std::string graph_json;
Map<std::string, Array<tir::LoweredFunc> > lowered_funcs;
Map<std::string, IRModule> lowered_funcs;
Array<tvm::runtime::Module> external_mods;
std::unordered_map<std::string, tvm::runtime::NDArray> params;
};
......@@ -214,19 +214,14 @@ class GraphRuntimeCodegen
LoweredOutput ret;
ret.graph_json = os.str();
ret.params = params_;
for (auto& kv : lowered_funcs_) {
if (ret.lowered_funcs.count(kv.first) == 0) {
ret.lowered_funcs.Set(kv.first, Array<tir::LoweredFunc>());
}
auto& vec = ret.lowered_funcs[kv.first];
Array<tir::LoweredFunc> tmp;
for (auto f : kv.second) {
tmp.push_back(f);
}
for (auto f : vec) {
tmp.push_back(f);
ret.lowered_funcs.Set(kv.first, IRModule::Empty());
}
ret.lowered_funcs.Set(kv.first, tmp);
auto& mod = ret.lowered_funcs[kv.first];
mod->Update(kv.second);
ret.lowered_funcs.Set(kv.first, mod);
}
ret.external_mods = compile_engine_->LowerExternalFunctions();
return ret;
......@@ -457,12 +452,9 @@ class GraphRuntimeCodegen
CCacheKey key = (*pf0)(func, target);
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = {};
lowered_funcs_[target->str()] = IRModule::Empty();
}
for (auto f : lowered_func->funcs) {
lowered_funcs_[target->str()].insert(f);
}
lowered_funcs_[target->str()]->Update(lowered_func->funcs);
return GraphAddCallNode(op,
_GetUniqueName(lowered_func->func_name),
lowered_func->func_name);
......@@ -602,8 +594,7 @@ class GraphRuntimeCodegen
/*! \brief plan memory of device result */
Map<Expr, Array<IntegerArray>> storage_device_map_;
/*! \brief lowered funcs */
std::unordered_map<std::string, std::unordered_set<tir::LoweredFunc, ObjectHash, ObjectEqual>>
lowered_funcs_;
std::unordered_map<std::string, IRModule> lowered_funcs_;
/*! \brief name map */
std::unordered_map<std::string, size_t> name_map_;
/*! \brief compile engine */
......@@ -655,7 +646,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
CHECK_GT(this->output_.params.count(key), 0);
*rv = this->output_.params[key];
});
} else if (name == "get_lowered_funcs") {
} else if (name == "get_irmodule") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->output_.lowered_funcs;
});
......
......@@ -226,6 +226,7 @@ std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
return raw_shape;
}
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
......@@ -407,12 +408,15 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CCacheKey key(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key);
int op_index = -1;
if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
// pick the only function inside the context
CHECK_EQ(cfunc->funcs->functions.size(), 1);
auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
if (context_->seen_funcs.count(pfunc) == 0) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
context_->seen_funcs[pfunc] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
op_index = context_->seen_funcs[pfunc];
}
// Prepare input and output registers
......@@ -494,13 +498,14 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
context_->cached_funcs.push_back(cfunc);
} else {
// TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1);
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
CHECK_EQ(cfunc->funcs->functions.size(), 1);
auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
context_->seen_funcs[pfunc] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
op_index = context_->seen_funcs[pfunc];
}
}
......@@ -862,11 +867,7 @@ void VMCompiler::Lower(IRModule mod,
// update primitive function map
size_t primitive_index = 0;
for (const auto& cfunc : context_.cached_funcs) {
if (cfunc->target->str() == "ext_dev") {
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
}
}
......@@ -961,8 +962,6 @@ void VMCompiler::PopulateGlobalMap() {
}
void VMCompiler::Codegen() {
using tir::LoweredFunc;
if (!context_.module.defined()) {
LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
return;
......@@ -971,15 +970,21 @@ void VMCompiler::Codegen() {
if (cached_funcs.size() == 0) {
return;
}
std::unordered_map<std::string, Array<LoweredFunc>> funcs;
std::unordered_map<std::string, IRModule> funcs;
for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str();
// NOTE: because the module is mutable, we need to make an
// explicit copy of the IRModule.
IRModule mod = cfunc->funcs;
mod.CopyOnWrite();
if (target_str == "ext_dev") {
continue;
} else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});
funcs.emplace(target_str, mod);
} else {
funcs[target_str].push_back(cfunc->funcs[0]);
funcs[target_str]->Update(mod);
}
}
......
......@@ -76,7 +76,7 @@ struct VMCompilerContext {
// List of cached functions
std::vector<CachedFunc> cached_funcs;
// The functions that have been lowered.
std::unordered_map<tir::LoweredFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
std::unordered_map<tir::PrimFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
};
......
......@@ -22,7 +22,6 @@
* \brief API for Automatic Differentiation for the Relay IR.
*/
#include <tvm/ir/type_functor.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h>
......
......@@ -31,29 +31,12 @@
#include <tvm/tir/function.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/lowered_func.h>
#include <unordered_map>
#include <string>
#include "../runtime/meta_data.h"
namespace tvm {
namespace codegen {
// Extract function information from device function.
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (tir::LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(f->args[i].dtype());
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
}
fmap[f->name] = info;
}
return fmap;
}
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule& mod) {
......
......@@ -43,50 +43,6 @@
namespace tvm {
namespace codegen {
// convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations.
Array<tir::Var> args;
Map<tir::Var, PrimExpr> remap_vars;
for (auto var : from->args) {
auto it = from->handle_data_type.find(var);
if (it != from->handle_data_type.end()) {
tir::Var new_var(var->name_hint,
PointerType(PrimType((*it).second->dtype)));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
args.push_back(var);
}
}
tir::PrimFunc func(args, Substitute(from->body, remap_vars));
func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name));
func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis);
if (from->func_type == tir::LoweredFuncType::kDeviceFunc) {
func = WithAttr(std::move(func),
attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch));
}
if (from->is_restricted) {
func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1));
}
return func;
}
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
Map<GlobalVar, BaseFunc> functions;
for (size_t i = 0; i < funcs.size(); ++i) {
auto f = funcs[i];
tir::PrimFunc pf = ToPrimFunc(f);
if (i == 0) {
pf = WithAttr(std::move(pf), tir::attr::kIsEntryFunc, Integer(1));
}
functions.Set(GlobalVar(f->name), pf);
}
return IRModule(functions);
}
runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) {
mod = tir::transform::SkipAssert()(mod);
......@@ -284,9 +240,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
TVM_REGISTER_GLOBAL("target.Build")
.set_body_typed(Build);
TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
.set_body_typed(ToIRModule);
// Export two auxiliary function to the runtime namespace.
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
.set_body_typed(PackImportsToC);
......
......@@ -448,7 +448,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif
// TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance?
// TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
......
......@@ -67,20 +67,23 @@ class LLVMModuleNode final : public runtime::ModuleNode {
} else if (name == "_get_target_triple") {
std::string target_triple = tm_->getTargetTriple().str();
return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) {
* rv = target_triple;
*rv = target_triple;
});
}
if (ee_ == nullptr) LazyInitJIT();
// This LLVMModule is empty and no function can be retrieved.
if (entry_func_.empty()) return nullptr;
std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name);
TVMBackendPackedCFunc faddr =
reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(fname));
TVMBackendPackedCFunc faddr;
if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetGlobalAddr(runtime::symbol::tvm_module_main));
CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(entry_name));
} else {
faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(name));
}
if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self);
}
......@@ -205,6 +208,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
std::vector<PrimFunc> funcs;
std::string entry_func;
for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs";
......@@ -212,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined());
entry_func_ = global_symbol;
entry_func = global_symbol;
}
funcs.push_back(f);
}
......@@ -225,8 +229,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
cg->AddFunction(f);
}
if (entry_func_.length() != 0) {
cg->AddMainFunction(entry_func_);
if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func);
}
module_ = cg->Finish();
......@@ -321,13 +325,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
CHECK(ee_ != nullptr)
<< "Failed to initialize jit engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false);
// setup context address.
// we will skip context setup if this LLVMModule is empty.
if (GetGlobalAddr(runtime::symbol::tvm_module_main) == 0)
return;
entry_func_ =
reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
if (void** ctx_addr = reinterpret_cast<void**>(
GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this;
......@@ -356,8 +354,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {
// The target configuration string
std::string target_;
// Name of entry function.
std::string entry_func_;
// JIT lock
std::mutex mutex_;
// execution engine
......
......@@ -29,7 +29,6 @@
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/container.h>
#include <string>
#include <vector>
......
......@@ -27,7 +27,6 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/lowered_func.h>
#include <vector>
#include <memory>
......
......@@ -26,7 +26,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/codegen.h>
#include <string>
#include <vector>
......
......@@ -22,8 +22,10 @@
* \brief Pass to check if memory accesses are legal.
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
namespace tvm {
......@@ -44,7 +46,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
public:
/// Special member functions
//@{
explicit MemoryAccessVerifier(LoweredFunc f, int device_type)
explicit MemoryAccessVerifier(PrimFunc f, int device_type)
: func_(f), dev_type_(device_type) {}
virtual ~MemoryAccessVerifier() = default;
MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
......@@ -116,7 +118,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
CHECK(V) << "Invalid Variable\n";
// Variable is from function args. Return true.
if (V == func_->args[0].get()) return true;
if (V == func_->params[0].get()) return true;
// The value is expected to come from a tvm_struct_get Call.
// Get the first argument of tvm_struct_get, and continue.
......@@ -179,18 +181,33 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
const ProducerConsumerNode *pc_{nullptr};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
//@}
LoweredFunc func_{nullptr}; ///< Function to be verified.
tir::PrimFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
std::unordered_map<const VarNode *, PrimExpr> defs_; ///< Variable definitions
};
} // namespace
/// Interface of VerifyMemory pass
bool VerifyMemory(LoweredFunc func, int device_type) {
MemoryAccessVerifier v(func, device_type);
void VerifyMemory(const IRModule& mod) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute";
MemoryAccessVerifier v(func, target->device_type);
v.Run();
return !v.Failed();
if (v.Failed()) {
LOG(FATAL)
<< "ValueError: Direct host side access to device memory is detected."
<< " Did you forget to bind?\n"
<< func;
}
}
}
}
TVM_REGISTER_GLOBAL("tir.analysis.verify_memory")
.set_body_typed(VerifyMemory);
} // namespace tir
} // namespace tvm
......@@ -48,7 +48,7 @@ Buffer decl_buffer(Array<PrimExpr> shape,
DataType dtype,
std::string name) {
return BufferNode::make(
Var(name, DataType::Handle()),
Var(name, PointerType(PrimType(dtype))),
dtype,
shape,
Array<PrimExpr>(),
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lowered_func.cc
*/
#include <tvm/tir/lowered_func.h>
namespace tvm {
namespace tir {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LoweredFuncNode*>(node.get());
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
} // namespace tir
} // namespace tvm
......@@ -105,13 +105,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
});
});
TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
.set_body([](TVMArgs args, TVMRetValue *ret) {
LoweredFunc f = args[0];
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
*ret = LoweredFunc(n);
});
// make from two arguments
#define REGISTER_PASS(PassName) \
......@@ -128,7 +121,6 @@ REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin);
REGISTER_PASS(MakeAPI);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
......@@ -138,9 +130,6 @@ REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
......
......@@ -994,29 +994,6 @@ class VectorAllocRewriter : public StmtExprMutator {
};
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
VectorAllocRewriter rewriter;
n->body = rewriter(n->body);
for (Var arg : f->args) {
if (arg.dtype().is_handle()) {
const auto& tvec = rewriter.acc_map_[arg.get()];
if (tvec.size() == 1) {
PrimExpr dtype = make_const(tvec[0], 0);
n->handle_data_type.Set(arg, dtype);
} else {
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
n->handle_data_type.Set(arg, dtype);
}
}
}
}
return LoweredFunc(n);
}
PrimFunc PointerValueTypeRewrite(PrimFunc f) {
auto* n = f.CopyOnWrite();
VectorAllocRewriter rewriter;
......
......@@ -22,7 +22,9 @@
*/
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
#include "../../target/datatype/registry.h"
namespace tvm {
......@@ -129,11 +131,26 @@ class CustomDatatypesLowerer : public StmtExprMutator {
std::string target_;
};
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = CustomDatatypesLowerer(target)(n->body);
return LoweredFunc(n);
namespace transform {
Pass LowerCustomDatatypes() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerCustomDatatypes: Require the target attribute";
n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
}
TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes")
.set_body_typed(LowerCustomDatatypes);
} // namespace transform
} // namespace tir
} // namespace tvm
......@@ -18,20 +18,24 @@
*/
/*!
* \file make_api.cc Build API function.
* \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/buffer.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <vector>
#include <utility>
#include <unordered_set>
#include "ir_util.h"
#include "arg_binder.h"
#include "../pass/ir_util.h"
#include "../pass/arg_binder.h"
namespace tvm {
namespace tir {
......@@ -40,14 +44,18 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
}
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<ObjectRef> api_args,
int num_unpacked_args,
bool is_restricted) {
PrimFunc MakePackedAPI(PrimFunc&& func,
int num_unpacked_args) {
auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined())
<< "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
std::string name_hint = global_symbol;
auto* func_ptr = func.CopyOnWrite();
const Stmt nop = EvaluateNode::make(0);
int num_args = static_cast<int>(api_args.size());
int num_args = static_cast<int>(func_ptr->params.size());
CHECK_LE(num_unpacked_args, num_args);
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
// The packed fields
......@@ -69,7 +77,8 @@ LoweredFunc MakeAPI(Stmt body,
// local function definitions
// load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) {
Array<PrimExpr> call_args{v_packed_args,
Array<PrimExpr> call_args{
v_packed_args,
IntImm(DataType::Int(32), i),
IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version
......@@ -83,13 +92,7 @@ LoweredFunc MakeAPI(Stmt body,
}
return res;
};
// get declaration of argument i
auto f_arg_decl = [&](int i) {
std::ostringstream os;
os << "arg" << i;
const VarNode* v = api_args[i].as<VarNode>();
return Var(os.str(), v ? v->dtype: DataType::Handle());
};
// ---------------------------
// start of logic
// add signature for packed arguments.
......@@ -99,16 +102,25 @@ LoweredFunc MakeAPI(Stmt body,
args.push_back(v_num_packed_args);
std::ostringstream os;
os << name << ": num_args should be " << num_packed_args;
os << name_hint << ": num_args should be " << num_packed_args;
seq_init.emplace_back(
MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
}
// Save the input variables and buffers that will be bound later.
std::vector<std::pair<Var, Var> > var_defs;
std::vector<std::pair<Buffer, Var> > buf_defs;
for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
Var v_arg = f_arg_decl(i);
// Need to re-declare vars, in case some arguments also appear in the buffer map.
std::vector<std::pair<Var, Var> > var_def;
std::vector<std::pair<Var, Buffer> > buffer_def;
for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
Var v_arg = Var("arg" + std::to_string(i), param->dtype);
auto it = func_ptr->buffer_map.find(param);
if (it != func_ptr->buffer_map.end()) {
buffer_def.emplace_back(v_arg, (*it).second);
} else {
var_def.emplace_back(v_arg, param);
}
if (i < num_packed_args) {
// Value loads
seq_init.emplace_back(LetStmtNode::make(
......@@ -123,7 +135,7 @@ LoweredFunc MakeAPI(Stmt body,
DataType t = v_arg.dtype();
if (t.is_handle()) {
std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be pointer";
msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back(
AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
tcode == kTVMNDArrayHandle ||
......@@ -131,27 +143,18 @@ LoweredFunc MakeAPI(Stmt body,
tcode == kTVMNullptr, msg.str(), nop));
} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be int";
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
} else {
CHECK(t.is_float());
std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be float";
msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back(
AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
}
} else {
args.push_back(v_arg);
}
// add checks for functions.
if (api_args[i].as<VarNode>()) {
var_defs.emplace_back(std::make_pair(Downcast<Var>(api_args[i]), v_arg));
} else {
// Buffer checks
CHECK(api_args[i].as<BufferNode>())
<< "api_args can only be Buffer or Var";
buf_defs.emplace_back(std::make_pair(Downcast<Buffer>(api_args[i]), v_arg));
}
}
// allow return value if the function is packed.
......@@ -170,24 +173,22 @@ LoweredFunc MakeAPI(Stmt body,
// either 0 or the original stride will be correctly used. Checks here have
// to use the args that may have no let binding yet. Therefore, hoisting let
// binding for args before buffer declaration is needed.
for (const auto& arg : var_defs) {
binder.Bind(arg.first, arg.second, arg.second->name_hint, true);
for (const auto& kv : var_def) {
binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
}
for (const auto& kv : buffer_def) {
binder.BindDLTensor(kv.second, device_type, device_id,
kv.first, kv.first->name_hint);
}
for (const auto& buf_arg : buf_defs) {
binder.BindDLTensor(buf_arg.first, device_type, device_id,
buf_arg.second, buf_arg.second->name_hint);
if (num_unpacked_args == 0) {
func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
}
ObjectPtr<LoweredFuncNode> n = make_object<LoweredFuncNode>();
n->name = name;
n->args = args;
n->handle_data_type = binder.def_handle_dtype();
n->is_packed_func = num_unpacked_args == 0;
n->is_restricted = is_restricted;
body = AttrStmtNode::make(
auto body = AttrStmtNode::make(
make_zero(DataType::Int(32)), attr::compute_scope,
StringImmNode::make(name + "_compute_"), body);
StringImmNode::make(name_hint + "_compute_"), func_ptr->body);
// Set device context
if (vmap.count(device_id.get())) {
PrimExpr node = StringImmNode::make("default");
......@@ -203,21 +204,59 @@ LoweredFunc MakeAPI(Stmt body,
device_type, device_id}, CallNode::Intrinsic)));
body = SeqStmt({set_device, body});
}
n->body = MergeNest(
func_ptr->body = MergeNest(
{seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f->body, f->args);
func_ptr->params = args;
Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
if (undefined.size() != 0) {
std::ostringstream os;
for (Var v : undefined) {
os << " \'" << v->name_hint << "\' ";
}
os << " does not appear in api_args";
os << " is not bound to any variables";
LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
}
return f;
func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32));
// return the function.
return std::move(func);
}
namespace transform {
Pass MakePackedAPI(int num_unpacked_args) {
auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
for (const auto& kv : mptr->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value
== static_cast<int>(CallingConv::kDefault)) {
auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args);
updates.push_back({kv.first, updated_func});
}
}
}
for (const auto& pair : updates) {
mptr->AddUnchecked(pair.first, pair.second);
}
return m;
};
return tvm::transform::CreateModulePass(
pass_func, 0, "tir.MakePackedAPI", {});
}
TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI")
.set_body_typed(MakePackedAPI);
} // namespace transform
} // namespace tir
} // namespace tvm
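``MakePackedAPI`` is now a module pass that rewrites every ``PrimFunc`` carrying the default calling convention. A sketch of driving it from Python, following the ``MakeAPILegacy`` adapter added in the tests below (the ramp computation is an illustrative stand-in):

::

    import tvm
    from tvm import te

    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), "float32")
    ib = tvm.tir.ir_builder.create()
    A = ib.buffer_ptr(Ab)
    with ib.for_range(0, n, "i") as i:
        A[i] = A[i] + 1.0

    # Wrap the statement in a PrimFunc; the global_symbol attribute supplies
    # the name hint that MakePackedAPI requires.
    f = tvm.tir.PrimFunc([Ab], ib.get()).with_attr(
        "global_symbol", tvm.runtime.String("ramp"))
    mod = tvm.IRModule.from_expr(f)
    # num_unpacked_args=0: all arguments go through the packed calling convention.
    mod = tvm.tir.transform.MakePackedAPI(0)(mod)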
......@@ -22,7 +22,8 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
......@@ -74,8 +75,8 @@ class ThreadAxisRewriter : private StmtExprMutator {
std::unordered_map<const VarNode*, Var> vmap_;
};
LoweredFunc
RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) {
const StringImmNode* str = kv.first.as<StringImmNode>();
......@@ -83,18 +84,33 @@ RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
tmap[str->value] = kv.second;
}
CHECK_EQ(f->func_type, kDeviceFunc);
auto n = make_object<LoweredFuncNode>(*f.operator->());
auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
auto* n = f.CopyOnWrite();
// replace the thread axis
for (size_t i = 0; i < n->thread_axis.size(); ++i) {
auto it = tmap.find(n->thread_axis[i]->thread_tag);
for (size_t i = 0; i < thread_axis.size(); ++i) {
auto it = tmap.find(thread_axis[i]->thread_tag);
if (it != tmap.end()) {
n->thread_axis.Set(i, it->second);
thread_axis.Set(i, it->second);
}
}
n->body = ThreadAxisRewriter(tmap).Rewrite(n->body);
return LoweredFunc(n);
n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body));
return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis);
}
namespace transform {
Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
return RemapThreadAxis(std::move(f), thread_map);
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {});
}
TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis")
.set_body_typed(RemapThreadAxis);
} // namespace transform
} // namespace tir
} // namespace tvm
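``RemapThreadAxis`` likewise becomes a ``PrimFunc`` pass: the thread-map keys are ``StringImm`` thread tags, and the function is expected to carry the ``kDeviceThreadAxis`` attribute (e.g. after ``SplitHostDevice``). A hedged sketch, assuming the Python wrapper mirrors the registered global:

::

    import tvm
    from tvm import te

    # Hypothetical remapping: run what was bound to threadIdx.x on threadIdx.y.
    remap = tvm.tir.transform.RemapThreadAxis(
        {tvm.tir.StringImm("threadIdx.x"): te.thread_axis("threadIdx.y")})
    # Applying it assumes `dev_mod` holds device PrimFuncs annotated with
    # the device thread axis attribute, e.g. the output of SplitHostDevice:
    # dev_mod = remap(dev_mod)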
......@@ -264,7 +264,6 @@ class HostDeviceSplitter : public StmtMutator {
std::string name_prefix_;
// Number of device functions.
int device_func_counter_{0};
std::vector<LoweredFunc> device_funcs_;
std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
......
......@@ -117,7 +117,7 @@ TEST(BuildModule, Heterogeneous) {
std::unordered_map<Tensor, Buffer> binds;
auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config);
auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config);
Map<tvm::Target, Array<LoweredFunc>> inputs = {{target_cuda, lowered_s1},
Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1},
{target_llvm, lowered_s2}};
auto module = build(inputs, Target(), config);
......
......@@ -18,29 +18,6 @@ import tvm
from tvm import te
import numpy as np
def lower(s, args, name="mydot"):
binds = {}
arg_list = []
for x in args:
assert isinstance(x, te.tensor.Tensor)
buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf
arg_list.append(buf)
s = s.normalize()
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 16)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
fapi = tvm.tir.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
return fapi
def mybuild(fapi, target="llvm"):
return
def test_dot():
nn = 12
......
......@@ -38,8 +38,9 @@ def test_dltensor_compatible():
with ib.for_range(0, n - 1, "i") as i:
A[i + 1] = A[i] + 1
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
mod = tvm.testing.LoweredFuncsToIRModule([fapi])
mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
f = tvm.target.codegen.build_module(mod, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
......
......@@ -156,7 +156,7 @@ def test_simplex_data_transferring():
elemwise_sub],
name="elemwise_sub")
target_flist = {target_device: [lower_add], target_host: [lower_sub]}
target_flist = {target_device: lower_add, target_host: lower_sub}
mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx]
mod = graph_runtime.create(graph, mhost, ctx)
......@@ -354,8 +354,9 @@ def test_duplex_data_transferring():
elemwise_sub],
name="elemwise_sub")
target_flist = {target_device: [lower_add0, lower_add1], target_host:
[lower_sub]}
lower_add0.update(lower_add1)
target_flist = {target_device: lower_add0, target_host:
lower_sub}
mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx]
params = {}
......
......@@ -57,8 +57,8 @@ def test_dso_module_load():
tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
m = tvm.driver.build(fapi, target="llvm")
m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
m = tvm.driver.build(m, target="llvm")
for name in names:
m.save(name)
......
......@@ -22,6 +22,7 @@ import numpy as np
import ctypes
import math
def test_llvm_intrin():
ib = tvm.tir.ir_builder.create()
n = tvm.runtime.convert(4)
......@@ -34,7 +35,8 @@ def test_llvm_intrin():
tvm.tir.Call(
"int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
body = ib.get()
func = tvm.tir.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm")
......@@ -85,7 +87,7 @@ def test_llvm_lookup_intrin():
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A)
ib.emit(x)
body = ib.get()
func = tvm.tir.ir_pass.MakeAPI(body, "ctpop", [A], 1, True)
func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
fcode = tvm.build(func, None, "llvm")
......@@ -307,8 +309,9 @@ def test_multiple_func():
f2 = tvm.lower(s, [A, B, C], name="fadd1")
f1 = tvm.lower(s, [A, B, C], name="fadd2")
m = tvm.build([f1, f2], "llvm")
fadd1 = m['fadd1']
fadd2 = m['fadd2']
fadd1 = m['fadd1']
ctx = tvm.cpu(0)
# launch the kernel.
n = nn
......@@ -665,6 +668,7 @@ def test_llvm_shuffle():
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
if __name__ == "__main__":
test_multiple_func()
test_llvm_large_uintimm()
test_llvm_import()
test_alignment()
......@@ -676,7 +680,6 @@ if __name__ == "__main__":
test_llvm_vadd_pipeline()
test_llvm_add_pipeline()
test_llvm_intrin()
test_multiple_func()
test_llvm_flip_pipeline()
test_llvm_madd_pipeline()
test_llvm_temp_space()
......
......@@ -19,6 +19,18 @@ from tvm import te
import ctypes
import numpy as np
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)
def test_static_callback():
dtype = 'int64'
n = te.size_var('n')
......@@ -32,7 +44,7 @@ def test_static_callback():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......@@ -55,7 +67,7 @@ def test_static_init():
return sh
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
......@@ -26,6 +26,18 @@ def run_jit(fapi, check):
s = f.get_source()
check(f)
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI()(mod)
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
......@@ -36,7 +48,7 @@ def test_stack_vm_basic():
n = te.size_var('n')
Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True)
fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
run_jit(fapi, lambda f: f(a))
......@@ -57,7 +69,7 @@ def test_stack_vm_loop():
ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
......@@ -79,7 +91,7 @@ def test_stack_vm_cond():
A[i + 1] = A[i] + 2
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True)
fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......@@ -98,7 +110,7 @@ def test_vm_parallel():
with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True)
fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......
......@@ -19,7 +19,6 @@ import tvm
from tvm import te
from ctypes import *
import topi
import tvm.tir.ir_pass as ir_pass
import numpy as np
tgt = "llvm"
......@@ -51,10 +50,12 @@ def lower_datatypes_and_build(schedule, args):
Once datatype lowering is integrated directly into TVM's lower/build
process, we won't need to do this manually.
TODO(gus) integrate datatype lowering into build process; change this test"""
flist = tvm.lower(schedule, args)
flist = [flist]
flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist]
return tvm.build(flist[0], target=tgt)
mod = tvm.lower(schedule, args)
target = tvm.target.create(tgt)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
mod = tvm.tir.transform.LowerCustomDatatypes()(mod)
return tvm.build(mod, target=tgt)
def test_bfloat_add_and_cast_1():
X = te.placeholder((3, ), name="X")
......
......@@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import pytest
from tvm import te
# The following devices correspond to DLDeviceType/TVMDeviceExtType values
# originally defined in dlpack.h and c_runtime_api.h.
gpu_devices = [2, 4, 7, 8, 10, 11]
other_devices = [1, 3, 9, 12]
gpu_devices = ["cuda", "opencl", "metal", "vulkan"]
other_devices = ["llvm", "ext_dev"]
def lower(sch, args):
......@@ -39,8 +40,11 @@ def lower(sch, args):
stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
func = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", arg_list, 0, True)
return func
f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String("test"))
mod = tvm.IRModule({"test": f})
return tvm.tir.transform.MakePackedAPI()(mod)
# All computations are bound.
......@@ -57,10 +61,13 @@ def test_verify_memory_all_bind():
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
func = lower(s, [A, B])
mod = lower(s, [A, B])
for dev_type in gpu_devices + other_devices:
assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
tvm.tir.analysis.verify_memory(binded_mod)
# Computations are not bound.
......@@ -74,12 +81,18 @@ def test_verify_memory_not_bind():
# B is not bound to threads.
s = te.create_schedule(B.op)
func = lower(s, [A, B])
mod = lower(s, [A, B])
for dev_type in gpu_devices:
assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type)
binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
with pytest.raises(ValueError):
tvm.tir.analysis.verify_memory(binded_mod)
for dev_type in other_devices:
assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
tvm.tir.analysis.verify_memory(binded_mod)
# Computations are partially bound.
......@@ -98,16 +111,22 @@ def test_verify_memory_partially_bind():
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
func = lower(s, [A, B, C, D])
mod = lower(s, [A, B, C, D])
for dev_type in gpu_devices:
assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type)
binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
with pytest.raises(ValueError):
tvm.tir.analysis.verify_memory(binded_mod)
for dev_type in other_devices:
assert tvm.tir.ir_pass.VerifyMemory(func, dev_type)
binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
tvm.tir.analysis.verify_memory(binded_mod)
if __name__ == "__main__":
test_verify_memory_all_bind()
test_verify_memory_not_bind()
test_verify_memory_partially_bind()
......@@ -118,7 +118,6 @@ def test_in_bounds_vectorize_llvm():
s[B].vectorize(xi)
# build and invoke the kernel.
lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False)
print (lowered_func.body)
f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
......@@ -137,7 +136,6 @@ def test_in_bounds_loop_partition_basic_llvm():
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -156,7 +154,6 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -205,12 +202,11 @@ def test_in_bounds_const_loop_partition_ir():
# after instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 2)
print (stmt)
branch_collector = list()
collect_visit(stmt, collect_branch_stmt)
assert(len(branch_collector) == 2)
print (branch_collector[0].condition)
print (branch_collector[1].condition)
def test_in_bounds_const_loop_partition_llvm():
with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
......@@ -222,7 +218,6 @@ def test_in_bounds_const_loop_partition_llvm():
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -242,7 +237,6 @@ def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b):
s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -276,7 +270,6 @@ def test_in_bounds_conv_llvm(loop_tiling=False):
if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
print (lowered_func.body)
ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm")
......@@ -320,7 +313,6 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False
if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
print (lowered_func.body)
ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm")
......@@ -341,7 +333,6 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm():
T = te.compute((m, ), lambda i: A[i]*B[i])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -361,7 +352,6 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape
T = te.compute((m, ), lambda i: A[i]*B[i])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -380,7 +370,6 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm():
T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -400,7 +389,6 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape
T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -419,7 +407,7 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm():
T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -439,7 +427,7 @@ def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape
T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm")
......@@ -460,7 +448,7 @@ def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm():
D = te.compute((), lambda : C + 1)
s = te.create_schedule(D.op)
stmt = tvm.lower (s, [A, scale, D], simple_mode=True)
print (stmt)
# build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm")
ctx = tvm.cpu(0)
......
......@@ -40,8 +40,7 @@ def test_double_buffer():
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
......
......@@ -381,7 +381,7 @@ def test_multilevel_splitting_with_indivisble_factors():
## But this does the right thing.
with tvm.target.build_config(partition_const_loop=True):
lowered_body = tvm.lower(s, [A, B]).body
lowered_body = tvm.lower(s, [A, B], name="x")["x"].body
def visit_stmt(op):
return(isinstance(op, tvm.tir.Max))
num_max = collect_visit(lowered_body, visit_stmt)
......@@ -407,7 +407,7 @@ def test_double_splitting_with_indivisible_factors():
# Find the beginning of the Halide IR corresponding to kernel code
# and make sure it doesn't have an if statements left
top_produce = find_top_produce(f.body)
top_produce = find_top_produce(f["fadd1"].body)
assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
# check functional correctness of generated code
......
......@@ -92,9 +92,7 @@ def test_flatten_double_buffer():
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0]
......
......@@ -36,12 +36,7 @@ def test_for():
ib.emit(tvm.tir.call_extern
("int32", "fadd", device_context(0), A))
body = ib.get()
f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
# temp adapter to convert loweredFunc to IRModule
# to test passes in the new style.x
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.testing.MakeAPILegacy(body, "func", [dev_type, n], 2, True)
mod = tvm.tir.transform.CombineContextCall()(mod)
assert mod["func"].body.value.dtype == "handle"
......
......@@ -35,10 +35,8 @@ def test_lower_warp_mem():
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
f = tvm.lower(s, [A, B], name="f")
mod = tvm.lower(s, [A, B], name="f")
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
......
......@@ -35,11 +35,11 @@ def test_makeapi():
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
num_unpacked_args = 2
f = tvm.tir.ir_pass.MakeAPI(
stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
assert(len(f.args) == 7)
output_ssa = False
f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr(
"tir.no_alias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
mod = tvm.IRModule.from_expr(f)
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert(len(f.params) == 7)
if __name__ == "__main__":
......
......@@ -37,10 +37,9 @@ def test_thread_storage_sync():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
cuda_target = tvm.target.create("cuda")
mod = tvm.testing.LoweredFuncsToIRModule([f])
cuda_target = tvm.target.create("cuda")
mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
......
......@@ -36,7 +36,7 @@ Before reading this tutorial, we assume readers have already known these topics
- Visitor design pattern. Otherwise, check the
`Python AST module <https://docs.python.org/3/library/ast.html>`_ to see how an AST
visitor is implemented.
- How a HalideIR/Schedule is lowered to either a LoweredFunc class or a LLVM module. Otherwise,
- How a Schedule is lowered to either an IRModule class or an LLVM module. Otherwise,
take a look at ``python/tvm/build_module.py`` to get some basics.
"""
......