Commit e63e08fe by Tianqi Chen (committed via GitHub)

[REFACTOR][TIR] Migrate all low-level passes to the Pass Manager. (#5233)

* [REFACTOR][TIR] Migrate all low-level passes to the Pass Manager.

This PR migrates tvm.lower to return an IRModule of PrimFuncs
instead of LoweredFuncs.

* Remove LoweredFunc.
parent fd9ce583
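At a glance, the new end-to-end flow looks like the following sketch (a minimal example assembled from the updated docstrings further below; the schedule itself is illustrative):

.. code-block:: python

   import tvm
   from tvm import te

   n = te.var("n")
   A = te.placeholder((n,), name="A")
   B = te.compute(A.shape, lambda i: A[i] + 1.0, name="B")
   s = te.create_schedule(B.op)

   m = tvm.lower(s, [A, B], name="add_one")  # now returns an IRModule of PrimFuncs
   rt_mod = tvm.build(m, target="llvm")      # build() consumes the IRModule directly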
...@@ -46,7 +46,6 @@ def __lldb_init_module(debugger, _):
     "tvm::IterVarAttr",
     "tvm::IterVarRelation",
     "tvm::Layout",
-    "tir::LoweredFunc",
     "tvm::Map",
     "tvm::Map",
     "tvm::MemoryInfo",
......
...@@ -145,15 +145,6 @@ After lowering is done, ``build()`` function generates target machine code from
 Code generation is done by the ``build_module()`` function, defined in ``python/tvm/target/codegen.py``. On the C++ side, code generation is implemented in the ``src/target/codegen`` subdirectory. The ``build_module()`` Python function will reach the ``Build()`` function below in ``src/target/codegen/codegen.cc``:
-::
-
-   runtime::Module Build(const Array<LoweredFunc>& funcs,
-                         const std::string& target) {
-     std::string build_f_name = "codegen.build_" + target;
-     const PackedFunc* bf = runtime::Registry::Get(build_f_name);
-     runtime::Module m = (*bf)(funcs, target);
-     return m;
-   }
 The ``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry and invokes the function found. For example, the ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this:
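The same registry is visible from Python; a quick way to check whether a backend is available is to look up its build function (a sketch; the ``codegen.build_cuda`` name follows the text above, and the lookup returns None if that backend is not compiled in):

.. code-block:: python

   import tvm

   # None if no such global function is registered
   bf = tvm.get_global_func("codegen.build_cuda", allow_missing=True)
   print("CUDA codegen available:", bf is not None)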
......
...@@ -32,8 +32,8 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/target/target.h>
 #include <tvm/support/with.h>
+#include <tvm/ir/module.h>
 #include <tvm/te/schedule_pass.h>
-#include <tvm/tir/lowered_func.h>
 #include <string>
 #include <vector>
...@@ -43,15 +43,15 @@
 namespace tvm {
 /*!
- * \brief Build a LoweredFunc given a schedule, args and binds
+ * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
  * \param config The build configuration.
- * \return The lowered function.
+ * \return The result module.
  */
-TVM_DLL Array<tir::LoweredFunc> lower(
+TVM_DLL IRModule lower(
     te::Schedule sch,
     const Array<te::Tensor>& args,
     const std::string& name,
...@@ -59,44 +59,43 @@ TVM_DLL Array<tir::LoweredFunc> lower(
     const BuildConfig& config);
 /*!
- * \brief Build a device and host module for a specific target from an array of lowered functions.
+ * \brief Build a device and host module for a specific target from an IRModule.
  * \param funcs The functions to be built.
  * \param target The target device to build for.
  * \param target_host The target for building host code. To use the default, pass Target()
 * \param config The build configuration.
 * \return The built module.
 */
-TVM_DLL runtime::Module build(const Array<tir::LoweredFunc>& funcs,
+TVM_DLL runtime::Module build(const IRModule& funcs,
                               const Target& target,
                               const Target& target_host,
                               const BuildConfig& config);
 /*!
  * \brief Build a device and host module for a specific target from a map
- * contains target to a list of lowered functions pairs. This function is used
+ * contains target to IRModule. This function is used
  * for heterogeneous build.
- * \param input The map contains target to a list of lowered functions pairs.
+ * \param input The map contains target to an IRModule.
  * \param target_host The target for building host code. To use the default,
  * pass Target().
  * \param config The build configuration.
  * \return The built module that contains code for different processors.
  */
-TVM_DLL runtime::Module build(const Map<Target, Array<tir::LoweredFunc>>& input,
+TVM_DLL runtime::Module build(const Map<Target, IRModule>& input,
                               const Target& target_host,
                               const BuildConfig& config);
 /*!
  * \brief Build a device and host module for a specific target from a map
- * contains target to a list of lowered functions pairs. This function is used
+ * contains target to IRModule. This function is used
  * for heterogeneous build.
- * \param input The map contains target string to a list of lowered functions
- * pairs.
+ * \param input The map contains target string to an IRModule.
  * \param target_host The target for building host code. To use the default,
  * pass Target().
  * \param config The build configuration.
 * \return The built module that contains code for different processors.
 */
-TVM_DLL runtime::Module build(const Map<std::string, Array<tir::LoweredFunc>>& input,
+TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input,
                               const Target& target_host,
                               const BuildConfig& config);
 }  // namespace tvm
......
...@@ -297,6 +297,15 @@ class IRModule : public ObjectRef {
     CHECK(ptr != nullptr);
     return static_cast<IRModuleNode*>(ptr);
   }
+  /*!
+   * \brief Construct an empty module.
+   *
+   * \returns The constructed module
+   */
+  static IRModule Empty() {
+    return IRModule(Map<GlobalVar, BaseFunc>());
+  }
   /*!
    * \brief Construct a module from a standalone expression.
    *
......
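On the Python side, the analogue of ``IRModule::Empty()`` is constructing a module from an empty map, which the updated build flow below uses as an accumulator (a minimal sketch):

.. code-block:: python

   import tvm

   mod = tvm.IRModule({})          # an IRModule with no functions
   assert len(mod.functions) == 0
   # functions from other modules can then be merged in via mod.update(other_mod)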
...@@ -27,7 +27,6 @@
 #include <tvm/runtime/packed_func.h>
 #include <tvm/ir/module.h>
 #include <tvm/tir/expr.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/target/target.h>
 #include <string>
...@@ -42,17 +41,6 @@ using runtime::TVMArgs;
 using runtime::TVMRetValue;
 /*!
- * \brief Temporary backward compatible function to convert a list
- *  of LoweredFunc to a IRModule of PrimfFuncs
- * \param funcs The input lowered function.
- * \return The IRModule.
- *
- * \note This function is only used for code refactor and will be
- *  removed once the refactor completes.
- */
-IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs);
-/*!
  * \brief Build a module from array of lowered function.
  * \param mod The Module to be built
  * \param target The target to be built.
......
...@@ -24,9 +24,12 @@
 #ifndef TVM_TIR_ANALYSIS_H_
 #define TVM_TIR_ANALYSIS_H_
+#include <tvm/ir/module.h>
 #include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
 #include <tvm/tir/stmt.h>
 namespace tvm {
 namespace tir {
...@@ -59,6 +62,18 @@ struct ExprDeepEqual {
  */
 Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
+/*!
+ * \brief Verify if memory accesses are legal for a specific target device type.
+ *
+ *  In the case that tgt is cuda, if not all workload is bound with
+ *  threads, CPU code is generated that tries to access GPU memory,
+ *  which is illegal. This pass performs verification for this case.
+ *
+ * \param mod The module to be verified.
+ * \return Success of memory verification.
+ */
+void VerifyMemory(const IRModule& mod);
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_ANALYSIS_H_
...@@ -31,7 +31,6 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/buffer.h>
 #include <tvm/tir/function.h>
-#include <tvm/tir/lowered_func.h>
 #include <unordered_map>
 #include <unordered_set>
...@@ -367,60 +366,6 @@ Stmt HoistIfThenElse(Stmt stmt);
 Stmt NarrowDataType(Stmt stmt, int target_bits);
 /*!
- * \brief Make an user callable API LoweredFunc.
- *
- *  The main task of this function is to create code to :
- *  - Map the values in the api_args to Var that is required by body.
- *  - Insert assertions to check type/value of the passed arguments.
- *
- * \param body The body of the function.
- * \param name The name of the function.
- * \param api_args Arguments to the function, can be either Var, or Buffer
- * \param num_unpacked_args Number of arguments that
- *        are processed in plain form instead of packed form.
- * \param is_restricted Whether the caller can guarantee that each buffer argument do not overlap.
- *        It is recommended to set to true for optimized code if such invariant holds.
- *
- * \return a LoweredFunc with the specified signiture.
- *
- * \note
- *  The function signature have two cases
- *
- *  let num_packed_args = len(api_args) - num_unpacked_args;
- *
- *  if num_packed_args is zero:
- *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
- *
- *  if num_packed_args is not zero:
- *     f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
- *       api_arg_k, api_arg_k+1, ... api_arg_n,
- *       TVMValue* out_ret_val, int* out_ret_tcode)
- *
- *       where n == len(api_args), k == num_packed_args
- *
- * There is no thread_axis in generated function.
- */
-LoweredFunc MakeAPI(Stmt body,
-                    std::string name,
-                    Array<ObjectRef> api_args,
-                    int num_unpacked_args,
-                    bool is_restricted);
-/*!
- * \brief Remap the thread axis
- *
- *  This can be used to get equivalent program which uses
- *  threadIdx.y in place of threadIdx.x by passing
- *  {"threadIdx.x": thread_axis("threadIdx.y")}
- *
- * \param f The device function to be lowered.
- * \param axis_map The map from StringImm -> ItrVar
- * \return Transformed function.
- */
-LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
-/*!
  * \brief Rewrite the pointer content type of arguments,
  *  as well as Alloc internal to the function to use
  *  the most frequently accessed type for load/store
...@@ -433,31 +378,6 @@ LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
 PrimFunc PointerValueTypeRewrite(PrimFunc f);
 /*!
- * \brief Lower custom datatypes.
- *
- *  See tvm::datatypes::Registry for more information on adding custom datatypes.
- *
- * \param f The device function to be lowered.
- * \param target The target device.
- * \return Transformed function.
- */
-LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
-/*!
- * \brief Verify if memory accesses are legal for a specific target device type.
- *
- *  In the case that tgt is cuda, if not all workload is bound with
- *  threads, CPU code is generated that tries to access GPU memory,
- *  which is illegal. This pass performs verification for this case.
- *
- * \param func The function to be verified.
- * \param device_type The target device type.
- * \return Success of memory verification.
- */
-bool VerifyMemory(LoweredFunc func, int device_type);
-/*!
  * \brief Verify the correctness of a GPU code
  *  It will check the whether the amount of memory usage or the number of threads
  *  in a block exceeds the limit
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/lowered_func.h
* \brief Information about a lowered TVM function.
* This data structure is the final step toward codegen.
*/
#ifndef TVM_TIR_LOWERED_FUNC_H_
#define TVM_TIR_LOWERED_FUNC_H_
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <string>
namespace tvm {
namespace tir {
// Internal node container of lowered function.
class LoweredFuncNode;
/*!
* \brief LoweredFunc represents function after lowering.
* This is the final IR representation before codegen.
*/
class LoweredFunc : public FunctionRef {
public:
LoweredFunc() {}
explicit LoweredFunc(ObjectPtr<Object> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const LoweredFuncNode* operator->() const;
/*! \brief specify container node */
using ContainerType = LoweredFuncNode;
};
/*! \brief specific type of lowered function */
enum LoweredFuncType : int {
/*! \brief Function that can mix device and host calls */
kMixedFunc = 0,
/*! \brief Only contains host code */
kHostFunc = 1,
/*! \brief Only contains device code */
kDeviceFunc = 2
};
/*! \brief Node container of LoweredFunc */
class LoweredFuncNode : public tir::FunctionBaseNode {
public:
/*! \brief The name of the function */
std::string name;
/*!
* \brief The arguments of the function
* This function can only take POD types (int, float) and void* as arguments.
*/
Array<Var> args;
/*!
* \brief The IterVar axis of threads
* Each axis needs the host function to specify a size.
* \note Calling convention into LoweredFunc
*
* Assume we have a LoweredFunc f, a call into f
* Call(f, arg1, arg2, ..., arg_n,
* size_axis_1, size_axis_2, ... size_axis_m)
*
* Here n = len(args), m = len(thread_axis)
*
* The CodeGen should take this and translate this call
* to the corresponding API-specific kernel launches or function calls.
*/
Array<IterVar> thread_axis;
/*!
* \brief The hint data type of Var handles defined in LetStmt
* Can be used as a hint when generating the type signature.
* The creation rule is given by
* handle_data_type[var_handle] = make_const(the_type, 0);
*
* \note Expr is used instead of Type, because Type cannot be held by Map.
* A constant Expr of the given type is used.
*/
Map<Var, PrimExpr> handle_data_type;
/*! \brief The type of the function */
LoweredFuncType func_type{kMixedFunc};
/*! \brief Whether this function is packed function */
bool is_packed_func{true};
/*!
* \brief Whether function ensures that argument pointers do not alias.
* This corresponds to restrict keyword in C.
*/
bool is_restricted{true};
/*! \brief The body statement of the function */
Stmt body;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
// there is no return value, but return 1
// to enable Call into this function.
int num_outputs() const final {
return 1;
}
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("func_type", &func_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("is_restricted", &is_restricted);
v->Visit("body", &body);
}
static constexpr const char* _type_key = "LoweredFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(LoweredFuncNode, Object);
};
// Implementations of inline functions
inline const LoweredFuncNode* LoweredFunc::operator->() const {
return static_cast<const LoweredFuncNode*>(get());
}
} // namespace tir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::tir::LoweredFunc> : public tvm::ObjectHash {
};
}
#endif // TVM_TIR_LOWERED_FUNC_H_
...@@ -59,6 +59,61 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
     const tvm::Array<tvm::PrimExpr>& required);
 /*!
+ * \brief Transform the high-level PrimFunc to a low-level version
+ *        that can be used as an API function.
+ *
+ *  The main task of this function is to create code to:
+ *  - Map the values in the api_args to Var that is required by the body.
+ *  - Insert assertions to check the type/value of the passed arguments.
+ *
+ * \param num_unpacked_args Number of arguments that
+ *        are processed in plain form instead of packed form.
+ *
+ * \note
+ *  The function signature has two cases:
+ *
+ *  let num_packed_args = len(api_args) - num_unpacked_args;
+ *
+ *  if num_packed_args is zero:
+ *     f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
+ *
+ *  if num_packed_args is not zero:
+ *     f(TVMArg* packed_args, int* packed_arg_type_ids, int num_packed_args,
+ *       api_arg_k, api_arg_k+1, ... api_arg_n,
+ *       TVMValue* out_ret_val, int* out_ret_tcode)
+ *
+ *       where n == len(api_args), k == num_packed_args
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass MakePackedAPI(int num_unpacked_args);
+/*!
+ * \brief Remap the thread axis.
+ *
+ *  This can be used to get an equivalent program which uses
+ *  threadIdx.y in place of threadIdx.x by passing
+ *  {"threadIdx.x": thread_axis("threadIdx.y")}
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass RemapThreadAxis(Map<PrimExpr, IterVar> axis_map);
+/*!
+ * \brief Lower custom datatypes.
+ *
+ *  See tvm::datatypes::Registry for more information on adding custom datatypes.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass LowerCustomDatatypes();
 /*!
  * \brief Bind the device type of the function to be
  *  the device_type specified in the target attribute.
  *
......
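Because these are now ordinary ``Pass`` objects, they compose with the pass infrastructure. A hedged Python sketch (assuming ``mod`` is an IRModule of PrimFuncs whose functions already carry a ``target`` attribute, as the build flow below arranges):

.. code-block:: python

   import tvm

   seq = tvm.transform.Sequential([
       tvm.tir.transform.MakePackedAPI(0),
       tvm.tir.transform.LowerCustomDatatypes(),
   ])
   mod = seq(mod)  # each pass rewrites the PrimFuncs inside the module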
...@@ -17,9 +17,6 @@
 # pylint: disable=invalid-name
 """The build utils in python.
-This module provides the functions to transform schedule to
-LoweredFunc and compiled Module.
 """
 import warnings
...@@ -30,7 +27,6 @@ from tvm.ir import container
 from tvm.ir import CallingConv
 from tvm.target import codegen, BuildConfig
 from tvm.tir import ir_pass
-from tvm.tir.stmt import LoweredFunc
 from tvm.te import tensor
 from tvm.te import schedule
 from tvm import target as _target
...@@ -136,8 +132,8 @@ def lower(sch,
     Returns
     -------
-    f : LoweredFunc or Stmt
-        The result function, if with_api_wrapper=False
+    m : IRModule or Stmt
+        The result IRModule, if simple_mode=False
         Then the Stmt before make api is returned.
     """
     cfg = BuildConfig.current()
...@@ -199,16 +195,21 @@ def lower(sch,
     if simple_mode:
         return stmt
-    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
+    f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
+        "global_symbol", tvm.runtime.String(name))
+    if cfg.restricted_func:
+        f = f.with_attr("tir.no_alias", True)
+    mod = tvm.IRModule({name: f})
+    return tvm.tir.transform.MakePackedAPI()(mod)
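With this change, the value returned by ``lower()`` can be indexed like any other IRModule (a small sketch with an illustrative schedule):

.. code-block:: python

   import tvm
   from tvm import te

   A = te.placeholder((16,), name="A")
   B = te.compute((16,), lambda i: A[i] * 2.0, name="B")
   s = te.create_schedule(B.op)

   mod = tvm.lower(s, [A, B], name="double")
   print(mod["double"])  # the PrimFunc wrapped by MakePackedAPI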
-def _build_for_device(flist, target, target_host):
+def _build_for_device(input_mod, target, target_host):
     """Build the lowered functions for a device with the given compilation
     target.
     Parameters
     ----------
-    flist : list of LoweredFunc
+    input_mod : IRModule
         The schedule to be built.
     target : str or :any:`tvm.target.Target`
...@@ -219,8 +220,8 @@ def _build_for_device(flist, target, target_host):
     Returns
     -------
-    fhost : list of LoweredFunc
-        A list of lowered functions for the host.
+    fhost : IRModule
+        The host IRModule.
     mdev : tvm.module
         A module that contains device code.
...@@ -229,14 +230,13 @@ def _build_for_device(flist, target, target_host):
     target_host = _target.create(target_host)
     device_type = ndarray.context(target.target_name, 0).device_type
-    for func in flist:
-        if not ir_pass.VerifyMemory(func, device_type):
-            raise ValueError(
-                "Direct host side access to device memory is detected in %s. "
-                "Did you forget to bind?" % func.name)
-    mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
-    opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))]
+    mod_mixed = input_mod
+    mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
+    tvm.tir.analysis.verify_memory(mod_mixed)
+    opt_mixed = []
+    if len(mod_mixed.functions) == 1:
+        opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
     if BuildConfig.current().detect_global_barrier:
         opt_mixed += [tvm.tir.transform.ThreadSync("global")]
     opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
...@@ -292,7 +292,7 @@ def build(inputs,
     Parameters
     ----------
-    inputs : tvm.te.Schedule, LoweredFunc, or dict of target to LoweredFunc list
+    inputs : tvm.te.Schedule, IRModule, or dict of target to IRModule
         The schedule to be built
     args : list of Buffer or Tensor or Var, optional
...@@ -326,7 +326,7 @@ def build(inputs,
     ________
     There are two typical example uses of this function depending on the type
     of the argument `inputs`:
-    1. it is a list of lowered functions:
+    1. it is an IRModule.
     .. code-block:: python
...@@ -335,10 +335,10 @@ def build(inputs,
        B = te.placeholder((n,), name='B')
        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s = tvm.te.create_schedule(C.op)
-       f = tvm.lower(s, [A, B, C], name="test_add")
-       m = tvm.build(f, target="llvm")
+       m = tvm.lower(s, [A, B, C], name="test_add")
+       rt_mod = tvm.build(m, target="llvm")
-    2. it is a dict of compilation target to list of lowered functions:
+    2. it is a dict of compilation target to IRModule.
     .. code-block:: python
...@@ -349,9 +349,9 @@ def build(inputs,
        s1 = tvm.te.create_schedule(C.op)
        with tvm.target.cuda() as cuda_tgt:
            s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
-           f1 = tvm.lower(s1, [A, B, C], name="test_add1")
-           f2 = tvm.lower(s2, [A, B, C], name="test_add2")
-           m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")
+           m1 = tvm.lower(s1, [A, B, C], name="test_add1")
+           m2 = tvm.lower(s2, [A, B, C], name="test_add2")
+           rt_mod = tvm.build({"llvm": m1, "cuda": m2}, target_host="llvm")
     Note
     ----
...@@ -360,45 +360,36 @@ def build(inputs,
     if isinstance(inputs, schedule.Schedule):
         if args is None:
             raise ValueError("args must be given for build from schedule")
-        flist = lower(inputs, args,
-                      name=name,
-                      binds=binds)
-        if isinstance(flist, LoweredFunc):
-            flist = [flist]
-    elif isinstance(inputs, LoweredFunc):
-        if args:
-            raise ValueError("args must be done when build from LoweredFunc.")
-        flist = [inputs]
+        input_mod = lower(inputs, args,
+                          name=name,
+                          binds=binds)
     elif isinstance(inputs, (list, tuple, container.Array)):
-        flist = inputs
+        merged_mod = tvm.IRModule({})
+        for x in inputs:
+            merged_mod.update(x)
+        input_mod = merged_mod
+    elif isinstance(inputs, tvm.IRModule):
+        input_mod = inputs
     elif not isinstance(inputs, (dict, container.Map)):
-        raise ValueError("inputs must be Schedule, LoweredFunc, list of "
-                         "LoweredFunc, or dict of target to list of "
-                         "LoweredFunc.")
+        raise ValueError("inputs must be Schedule, IRModule or dict of target to IRModule")
     if not isinstance(inputs, (dict, container.Map)):
         target = _target.Target.current() if target is None else target
         target = target if target else "llvm"
-        target_flist = {target: flist}
+        target_input_mod = {target: input_mod}
     else:
-        target_flist = inputs
-    for tar, flist in target_flist.items():
+        target_input_mod = inputs
+    for tar, mod in target_input_mod.items():
         if not isinstance(tar, (str, _target.Target)):
             raise ValueError("The key of inputs must be str or "
                              "_target.Target when inputs is dict.")
-        fname_set = set()
-        for x in flist:
-            if not isinstance(x, LoweredFunc):
-                raise ValueError("inputs must be Schedule, LoweredFunc, list "
-                                 "of LoweredFunc, or dict of str to list of "
-                                 "LoweredFunc.")
-            if x.name in fname_set:
-                raise ValueError("Duplicate function name %s" % x.name)
-            fname_set.add(x.name)
+        if not isinstance(mod, tvm.IRModule):
+            raise ValueError("inputs must be Schedule, IRModule,"
+                             "or dict of str to IRModule.")
     if not target_host:
-        for tar, _ in target_flist.items():
+        for tar, _ in target_input_mod.items():
             tar = _target.create(tar)
             device_type = ndarray.context(tar.target_name, 0).device_type
             if device_type == ndarray.cpu(0).device_type:
...@@ -410,8 +401,8 @@ def build(inputs,
     mod_host_all = tvm.IRModule({})
     device_modules = []
-    for tar, flist in target_flist.items():
-        mod_host, mdev = _build_for_device(flist, tar, target_host)
+    for tar, input_mod in target_input_mod.items():
+        mod_host, mdev = _build_for_device(input_mod, tar, target_host)
         mod_host_all.update(mod_host)
         device_modules.append(mdev)
......
...@@ -17,7 +17,6 @@
 """The interface of expr function exposed from C++."""
 import tvm._ffi
 import tvm.driver
-from tvm.ir import container as _container
 @tvm._ffi.register_func("relay.backend.lower")
...@@ -40,7 +39,7 @@ def lower(sch, inputs, func_name, source_func):
     Returns
     -------
-    lowered_funcs : List[tvm.LoweredFunc]
+    mod : tvm.IRModule
         The result of lowering.
     """
     # pylint: disable=broad-except, import-outside-toplevel
...@@ -56,20 +55,17 @@ def lower(sch, inputs, func_name, source_func):
         msg += "-----------------------------\n"
         msg += source_func.astext()
         raise RuntimeError(msg)
-    return f if isinstance(
-        f, (_container.Array, tuple, list)) else [f]
+    return f
 @tvm._ffi.register_func("relay.backend.build")
-def build(funcs, target, target_host=None):
+def build(mod, target, target_host=None):
     """Backend build function.
     Parameters
     ----------
-    funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
-        A list of lowered functions or dictionary mapping from targets to
-        lowered functions.
+    mod : tvm.IRModule or Dict[str, tvm.IRModule]
+        Input module
     target : tvm.Target
         The target to run the code on.
...@@ -84,7 +80,7 @@ def build(funcs, target, target_host=None):
     """
     if target_host == "":
         target_host = None
-    return tvm.driver.build(funcs, target=target, target_host=target_host)
+    return tvm.driver.build(mod, target=target, target_host=target_host)
 @tvm._ffi.register_func("relay._tensor_value_repr")
......
...@@ -48,7 +48,7 @@ class GraphRuntimeCodegen(object):
         self._get_graph_json = self._mod["get_graph_json"]
         self._list_params_name = self._mod["list_params_name"]
         self._get_param_by_name = self._mod["get_param_by_name"]
-        self._get_lowered_funcs = self._mod["get_lowered_funcs"]
+        self._get_irmodule = self._mod["get_irmodule"]
         self._setup(mod, target)
     def _setup(self, mod, target):
...@@ -74,14 +74,14 @@ class GraphRuntimeCodegen(object):
        -------
        graph_json : str
            The graph json that can be consumed by runtime.
-       lowered_funcs : List[tvm.LoweredFunc] or Dict[str, List[tvm.LoweredFunc]]
+       mod : IRModule or Dict[str, IRModule]
            The lowered functions.
        params : Dict[str, tvm.nd.NDArray]
            Additional constant parameters.
        """
        self._codegen(func)
        graph_json = self._get_graph_json()
-       lowered_func = self._get_lowered_funcs()
+       lowered_func = self._get_irmodule()
        param_names = self._list_params_name()
        params = {}
        for name in param_names:
......
...@@ -28,3 +28,4 @@ from .object_generic import convert_to_object, convert, const
 from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
 from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev
 from .module import load_module, enabled, system_lib
+from .container import String
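The newly exported ``tvm.runtime.String`` is what the refactored code uses for string-typed function attributes such as ``global_symbol`` (a one-line sketch):

.. code-block:: python

   import tvm

   name = tvm.runtime.String("test_add")  # an Object, so it can be stored as a node attribute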
...@@ -20,9 +20,7 @@ import tvm._ffi
 import tvm.ir
 from tvm.runtime import Object
-from tvm.ir import container
 from tvm.tir import Stmt
-from tvm.tir.stmt import LoweredFunc
 from . import _ffi_api
...@@ -48,17 +46,13 @@ class DumpIR(object):
         def dump(*args, **kwargs):
             """dump function"""
             retv = func(*args, **kwargs)
-            if not isinstance(retv, (Stmt, LoweredFunc, container.Array)):
+            if not isinstance(retv, (Stmt,)):
                 return retv
             fname = func.func_name if hasattr(func, 'func_name') else func.__name__
             pname = str(self._pass_id) + "_" + fname + "_ir.cc"
             with open(pname, "a") as f:
-                out = retv.body if isinstance(retv, LoweredFunc) else retv
+                out = retv
                 f.write(str(out))
-                if isinstance(retv, container.Array):
-                    for x in retv:
-                        out = x.body if isinstance(x, LoweredFunc) else x
-                        f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
             self._pass_id += 1
             return retv
         return dump
......
...@@ -14,9 +14,12 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name
 """ TVM testing utilities """
 import logging
 import numpy as np
+import tvm
 import tvm._ffi
...@@ -165,4 +168,40 @@ def check_numerical_grads(function, input_values, grad_values, function_value=No
             x_name, grad.shape, dist, max_diff, avg_diff)
+def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
+    """Legacy adapter to build a Module from a statement.
+    Used for migrating existing test cases only.
+    Parameters
+    ----------
+    stmt: Stmt
+        The input statement.
+    name: str
+        The name of the function.
+    args: list of Buffer or Vars
+        The function arguments.
+    num_unpacked_args: int
+        Number of unpacked arguments.
+    noalias: bool
+        Whether to allow noalias.
+    Returns
+    -------
+    mod : IRModule
+        The created IRModule.
+    """
+    f = tvm.tir.PrimFunc(args, stmt).with_attr(
+        "global_symbol", tvm.runtime.String(name))
+    f = f.with_attr("tir.is_entry_func", True)
+    if noalias:
+        f = f.with_attr("tir.no_alias", True)
+    mod = tvm.IRModule({name: f})
+    return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
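For example, a test that previously called ``ir_pass.MakeAPI`` directly can migrate like this (a sketch using a trivial statement; real tests pass the Stmt produced by schedule lowering):

.. code-block:: python

   import tvm

   stmt = tvm.tir.Evaluate(0)  # placeholder body
   mod = tvm.testing.MakeAPILegacy(stmt, "main", [], 0, True)
   print(mod["main"])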
tvm._ffi._init_api("testing", __name__) tvm._ffi._init_api("testing", __name__)
...@@ -29,7 +29,7 @@ from .expr import IterVar, Any
 from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
 from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
-from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
+from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
 from .function import PrimFunc
......
...@@ -55,3 +55,14 @@ def expr_deep_equal(lhs, rhs):
     tvm.ir.structural_equal
     """
     return _ffi_api.expr_deep_equal(lhs, rhs)
+def verify_memory(mod):
+    """Verify if module contains illegal host side direct memory access.
+    Parameters
+    ----------
+    mod: tvm.IRModule
+        The module to be verified.
+    """
+    _ffi_api.verify_memory(mod)
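A hedged usage sketch (mirroring ``_build_for_device`` above: the functions must carry a ``target`` attribute before verification):

.. code-block:: python

   import tvm
   from tvm import te

   A = te.placeholder((16,), name="A")
   B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
   s = te.create_schedule(B.op)
   mod = tvm.lower(s, [A, B], name="add")

   tgt = tvm.target.create("llvm")
   mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", tgt))(mod)
   tvm.tir.analysis.verify_memory(mod)  # fails loudly on illegal host-side device access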
...@@ -18,6 +18,7 @@
 import tvm._ffi
 import tvm.runtime
+from tvm.runtime import Object
 from tvm.ir import BaseFunc
 from .buffer import Buffer
 from .expr import Var
...@@ -54,6 +55,7 @@ class PrimFunc(BaseFunc):
         param_list = []
         buffer_map = {} if buffer_map is None else buffer_map
         for x in params:
+            x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
             if isinstance(x, Buffer):
                 var = Var(x.name, dtype="handle")
                 param_list.append(var)
......
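With the conversion in place, ``PrimFunc`` parameters may be given as ``Buffer`` objects, and the matching handle ``Var`` plus ``buffer_map`` entry are created automatically (a small sketch):

.. code-block:: python

   import tvm

   Ab = tvm.tir.decl_buffer((10,), "float32", name="A")
   f = tvm.tir.PrimFunc([Ab], tvm.tir.Evaluate(0))
   print(f.params)      # a fresh handle Var standing in for the Buffer
   print(f.buffer_map)  # maps that Var back to the Buffer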
...@@ -385,14 +385,6 @@ class Prefetch(Stmt):
             _ffi_api.Prefetch, func, value_index, dtype, bounds)
-@tvm._ffi.register_object
-class LoweredFunc(Object):
-    """Represent a LoweredFunc in TVM."""
-    MixedFunc = 0
-    HostFunc = 1
-    DeviceFunc = 2
 def stmt_seq(*args):
     """Make sequence of statements
......
...@@ -60,6 +60,36 @@ def Filter(fcond):
     return _fpass.prim_func_pass(_transform, opt_level=0)
+def LowerCustomDatatypes():
+    """Lower custom datatypes.
+    See tvm::datatypes::Registry for more information on adding custom datatypes.
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerCustomDatatypes()
+def MakePackedAPI(num_unpacked_params=0):
+    """Transform the PrimFuncs in the module to a packed func API.
+    Parameters
+    ----------
+    num_unpacked_params : int
+        Number of parameters that we hope to directly pass via normal arguments
+        following the PackedFunc input signature.
+    Returns
+    -------
+    fpass : tvm.ir.transform.Pass
+        The result pass
+    """
+    return _ffi_api.MakePackedAPI(num_unpacked_params)
 def BindDeviceType():
     """Bind the device type of the function to be
     the device_type specified in the target attribute.
......
...@@ -27,7 +27,6 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/target/codegen.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/te/schedule.h>
 #include <map>
 #include <string>
......
...@@ -26,8 +26,10 @@
 #include <tvm/te/operation.h>
 #include <tvm/tir/transform.h>
+#include <tvm/tir/analysis.h>
 #include <tvm/tir/ir_pass.h>
 #include <tvm/target/codegen.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/registry.h>
 #include <algorithm>
...@@ -39,7 +41,6 @@ namespace tvm {
 using runtime::TVMArgs;
 using runtime::TVMRetValue;
 using runtime::PackedFunc;
-using tir::LoweredFunc;
 bool LLVMEnabled() {
   const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
...@@ -166,17 +167,6 @@ tir::Stmt BuildStmt(te::Schedule sch,
   return stmt;
 }
-Array<LoweredFunc> lower(te::Schedule sch,
-                         const Array<te::Tensor>& args,
-                         const std::string& name,
-                         const std::unordered_map<te::Tensor, tir::Buffer>& binds,
-                         const BuildConfig& config) {
-  Array<ObjectRef> out_arg_list;
-  auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
-  return Array<LoweredFunc>({ tir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
-}
 transform::Pass BindTarget(Target target) {
   auto fpass = [target](tir::PrimFunc f, IRModule m, transform::PassContext ctx) {
     return WithAttr(std::move(f), tvm::attr::kTarget, target);
...@@ -198,18 +188,46 @@ transform::Pass FilterBy(FCond fcond) {
 }
+IRModule lower(te::Schedule sch,
+               const Array<te::Tensor>& args,
+               const std::string& name,
+               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+               const BuildConfig& config) {
+  Array<ObjectRef> out_arg_list;
+  auto stmt = BuildStmt(sch, args, binds, true, &out_arg_list, config);
+  Array<tir::Var> params;
+  Map<tir::Var, tir::Buffer> buffer_map;
+  for (auto var : out_arg_list) {
+    if (auto* n = var.as<tir::VarNode>()) {
+      params.push_back(GetRef<tir::Var>(n));
+    } else {
+      tir::Buffer buffer = Downcast<tir::Buffer>(var);
+      tir::Var bptr(buffer->name, DataType::Handle());
+      params.push_back(bptr);
+      buffer_map.Set(bptr, buffer);
+    }
+  }
+  auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+  if (config->restricted_func) {
+    f = WithAttr(std::move(f), "tir.no_alias", Integer(1));
+  }
+  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+  return tir::transform::MakePackedAPI(0)(mod);
+}
 std::pair<IRModule, IRModule>
-split_dev_host_funcs(const Array<LoweredFunc>& funcs,
+split_dev_host_funcs(IRModule mod_mixed,
                      const Target& target,
                      const Target& target_host,
                      const BuildConfig& config) {
-  for (const auto& x : funcs) {
-    CHECK(tir::VerifyMemory(x, target->device_type))
-        << "Direct host side access to device memory is detected in "
-        << x->func_name() << ". Did you forget to bind?";
-  }
-  IRModule mod_mixed = codegen::ToIRModule(funcs);
+  mod_mixed = BindTarget(target)(std::move(mod_mixed));
+  tir::VerifyMemory(mod_mixed);
   Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
   if (config->detect_global_barrier) {
...@@ -274,10 +292,9 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs,
 // Build for heterogeneous execution.
-runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
+runtime::Module build(const Map<Target, IRModule>& inputs,
                       const Target& target_host,
                       const BuildConfig& config) {
-  Array<LoweredFunc> fhost_all;
   std::vector<runtime::Module> device_modules;
   Target target_host_val = target_host;
...@@ -319,10 +336,10 @@ runtime::Module build(const Map<Target, Array<LoweredFunc>>& inputs,
 }
 // Build for heterogeneous execution when target is a string.
-runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
+runtime::Module build(const Map<std::string, IRModule>& inputs,
                       const Target& target_host,
                       const BuildConfig& config) {
-  Map<Target, Array<LoweredFunc>> updated_input;
+  Map<Target, IRModule> updated_input;
   for (const auto& it : inputs) {
     auto target = Target::Create(it.first);
     if (target->device_name == "vta") {
...@@ -334,11 +351,11 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
 }
 // Build for homogeneous execution.
-runtime::Module build(const Array<LoweredFunc>& funcs,
+runtime::Module build(const IRModule& funcs,
                       const Target& target,
                       const Target& target_host,
                       const BuildConfig& config) {
-  Map<Target, Array<LoweredFunc>> inputs = {{target, funcs}};
+  Map<Target, IRModule> inputs = {{target, funcs}};
   return build(inputs, target_host, config);
 }
......
...@@ -38,7 +38,6 @@ namespace tvm {
 namespace relay {
 namespace backend {
-using tir::LoweredFunc;
 using TargetsMap = Map<tvm::Integer, tvm::Target>;
 using namespace tvm::relay::transform;
...@@ -78,16 +77,16 @@ struct GraphCodegen {
   }
   Array<tvm::runtime::Module> GetExternalModules() {
-    return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
+    return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
   }
-  Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
-    return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
+  Map<std::string, IRModule> GetIRModule() {
+    return CallFunc<Map<std::string, IRModule>>("get_irmodule", nullptr);
   }
   std::unordered_map<std::string, tvm::runtime::NDArray> GetParams() {
     std::unordered_map<std::string, tvm::runtime::NDArray> ret;
-    auto names = CallFunc<Array<tvm::PrimExpr> >("list_params_name", nullptr);
+    auto names = CallFunc<Array<tvm::PrimExpr>>("list_params_name", nullptr);
     for (auto expr : names) {
       auto key = expr.as<tir::StringImmNode>()->value;
       ret[key] = CallFunc<runtime::NDArray>("get_param_by_name", key);
...@@ -152,9 +151,9 @@ class RelayBuildModule : public runtime::ModuleNode {
         this->SetParam(kv.first, kv.second->data);
       }
     });
-  } else if (name == "get_lowered_funcs") {
+  } else if (name == "get_irmodule") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
-      *rv = this->graph_codegen_->GetLoweredFunc();
+      *rv = this->graph_codegen_->GetIRModule();
     });
   } else if (name == "get_external_modules") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
...@@ -452,7 +451,7 @@
     ret_.graph_json = graph_codegen_->GetJSON();
     ret_.params = graph_codegen_->GetParams();
-    auto lowered_funcs = graph_codegen_->GetLoweredFunc();
+    auto lowered_funcs = graph_codegen_->GetIRModule();
     // When there is no lowered_funcs due to reasons such as optimization.
     if (lowered_funcs.size() == 0) {
......
...@@ -27,7 +27,6 @@
 #include <tvm/node/structural_equal.h>
 #include <tvm/node/structural_hash.h>
-#include <tvm/tir/lowered_func.h>
 #include <tvm/runtime/module.h>
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
...@@ -82,7 +81,8 @@ struct CachedFuncNode : public Object {
   /*! \brief The schedule to the function */
   te::Schedule schedule;
   /*! \brief The lowered functions to support the function. */
-  tvm::Array<tir::LoweredFunc> funcs;
+  IRModule funcs = IRModule::Empty();
   /*! \brief Parameter usage states in the shape function. */
   tvm::Array<Integer> shape_func_param_states;
......
...@@ -55,7 +55,7 @@ using TargetsMap = std::unordered_map<int, Target>;
 /*! \brief Lowered outputs */
 struct LoweredOutput {
   std::string graph_json;
-  Map<std::string, Array<tir::LoweredFunc> > lowered_funcs;
+  Map<std::string, IRModule> lowered_funcs;
   Array<tvm::runtime::Module> external_mods;
   std::unordered_map<std::string, tvm::runtime::NDArray> params;
 };
...@@ -214,19 +214,14 @@ class GraphRuntimeCodegen
     LoweredOutput ret;
     ret.graph_json = os.str();
     ret.params = params_;
     for (auto& kv : lowered_funcs_) {
       if (ret.lowered_funcs.count(kv.first) == 0) {
-        ret.lowered_funcs.Set(kv.first, Array<tir::LoweredFunc>());
-      }
-      auto& vec = ret.lowered_funcs[kv.first];
-      Array<tir::LoweredFunc> tmp;
-      for (auto f : kv.second) {
-        tmp.push_back(f);
-      }
-      for (auto f : vec) {
-        tmp.push_back(f);
+        ret.lowered_funcs.Set(kv.first, IRModule::Empty());
       }
-      ret.lowered_funcs.Set(kv.first, tmp);
+      auto& mod = ret.lowered_funcs[kv.first];
+      mod->Update(kv.second);
+      ret.lowered_funcs.Set(kv.first, mod);
     }
     ret.external_mods = compile_engine_->LowerExternalFunctions();
     return ret;
...@@ -457,12 +452,9 @@ class GraphRuntimeCodegen
     CCacheKey key = (*pf0)(func, target);
     CachedFunc lowered_func = (*pf1)(compile_engine_, key);
     if (!lowered_funcs_.count(target->str())) {
-      lowered_funcs_[target->str()] = {};
+      lowered_funcs_[target->str()] = IRModule::Empty();
     }
-    for (auto f : lowered_func->funcs) {
-      lowered_funcs_[target->str()].insert(f);
-    }
+    lowered_funcs_[target->str()]->Update(lowered_func->funcs);
     return GraphAddCallNode(op,
                             _GetUniqueName(lowered_func->func_name),
                             lowered_func->func_name);
...@@ -602,8 +594,7 @@ class GraphRuntimeCodegen
   /*! \brief plan memory of device result */
   Map<Expr, Array<IntegerArray>> storage_device_map_;
   /*! \brief lowered funcs */
-  std::unordered_map<std::string, std::unordered_set<tir::LoweredFunc, ObjectHash, ObjectEqual>>
-      lowered_funcs_;
+  std::unordered_map<std::string, IRModule> lowered_funcs_;
   /*! \brief name map */
   std::unordered_map<std::string, size_t> name_map_;
   /*! \brief compile engine */
...@@ -655,7 +646,7 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode {
     CHECK_GT(this->output_.params.count(key), 0);
     *rv = this->output_.params[key];
   });
-  } else if (name == "get_lowered_funcs") {
+  } else if (name == "get_irmodule") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       *rv = this->output_.lowered_funcs;
     });
......
...@@ -226,6 +226,7 @@ std::vector<int64_t> ToAllocTensorShape32(NDArray shape) { ...@@ -226,6 +226,7 @@ std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
return raw_shape; return raw_shape;
} }
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public: public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
...@@ -407,12 +408,15 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -407,12 +408,15 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
CCacheKey key(func, target_host_); CCacheKey key(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key); auto cfunc = engine_->LowerShapeFunc(key);
int op_index = -1; int op_index = -1;
if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) { // pick the only function inside the context
CHECK_EQ(cfunc->funcs->functions.size(), 1);
auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
if (context_->seen_funcs.count(pfunc) == 0) {
op_index = context_->cached_funcs.size(); op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc); context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index; context_->seen_funcs[pfunc] = op_index;
} else { } else {
op_index = context_->seen_funcs[cfunc->funcs[0]]; op_index = context_->seen_funcs[pfunc];
} }
// Prepare input and output registers // Prepare input and output registers
...@@ -494,13 +498,14 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> { ...@@ -494,13 +498,14 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
context_->cached_funcs.push_back(cfunc); context_->cached_funcs.push_back(cfunc);
} else { } else {
// TODO(jroesch): support lowered funcs for multiple targets // TODO(jroesch): support lowered funcs for multiple targets
CHECK_EQ(cfunc->funcs.size(), 1); CHECK_EQ(cfunc->funcs->functions.size(), 1);
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { auto pfunc = Downcast<tir::PrimFunc>((*cfunc->funcs->functions.begin()).second);
if (context_->seen_funcs.find(pfunc) == context_->seen_funcs.end()) {
op_index = context_->cached_funcs.size(); op_index = context_->cached_funcs.size();
context_->cached_funcs.push_back(cfunc); context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index; context_->seen_funcs[pfunc] = op_index;
} else { } else {
op_index = context_->seen_funcs[cfunc->funcs[0]]; op_index = context_->seen_funcs[pfunc];
} }
} }
...@@ -862,11 +867,7 @@ void VMCompiler::Lower(IRModule mod, ...@@ -862,11 +867,7 @@ void VMCompiler::Lower(IRModule mod,
// update primitive function map // update primitive function map
size_t primitive_index = 0; size_t primitive_index = 0;
for (const auto& cfunc : context_.cached_funcs) { for (const auto& cfunc : context_.cached_funcs) {
if (cfunc->target->str() == "ext_dev") { exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
exec_->primitive_map.insert({cfunc->func_name, primitive_index++});
} else {
exec_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});
}
} }
} }
...@@ -961,8 +962,6 @@ void VMCompiler::PopulateGlobalMap() { ...@@ -961,8 +962,6 @@ void VMCompiler::PopulateGlobalMap() {
} }
void VMCompiler::Codegen() { void VMCompiler::Codegen() {
using tir::LoweredFunc;
if (!context_.module.defined()) { if (!context_.module.defined()) {
LOG(WARNING) << "Did you forget to call VMCompiler::Lower?"; LOG(WARNING) << "Did you forget to call VMCompiler::Lower?";
return; return;
...@@ -971,15 +970,21 @@ void VMCompiler::Codegen() { ...@@ -971,15 +970,21 @@ void VMCompiler::Codegen() {
if (cached_funcs.size() == 0) { if (cached_funcs.size() == 0) {
return; return;
} }
std::unordered_map<std::string, Array<LoweredFunc>> funcs; std::unordered_map<std::string, IRModule> funcs;
for (auto& cfunc : cached_funcs) { for (auto& cfunc : cached_funcs) {
std::string target_str = cfunc->target->str(); std::string target_str = cfunc->target->str();
// NOTE: because the module is mutable, we need to make an
// explicit copy of the IRModule.
IRModule mod = cfunc->funcs;
mod.CopyOnWrite();
if (target_str == "ext_dev") { if (target_str == "ext_dev") {
continue; continue;
} else if (funcs.count(target_str) == 0) { } else if (funcs.count(target_str) == 0) {
funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]}); funcs.emplace(target_str, mod);
} else { } else {
funcs[target_str].push_back(cfunc->funcs[0]); funcs[target_str]->Update(mod);
} }
} }
......
...@@ -76,7 +76,7 @@ struct VMCompilerContext { ...@@ -76,7 +76,7 @@ struct VMCompilerContext {
// List of cached functions // List of cached functions
std::vector<CachedFunc> cached_funcs; std::vector<CachedFunc> cached_funcs;
// The functions that have been lowered. // The functions that have been lowered.
std::unordered_map<tir::LoweredFunc, size_t, ObjectHash, ObjectEqual> seen_funcs; std::unordered_map<tir::PrimFunc, size_t, ObjectHash, ObjectEqual> seen_funcs;
}; };
......
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* \brief API for Automatic Differentiation for the Relay IR. * \brief API for Automatic Differentiation for the Relay IR.
*/ */
#include <tvm/ir/type_functor.h> #include <tvm/ir/type_functor.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/te/operation.h> #include <tvm/te/operation.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
......
...@@ -31,29 +31,12 @@ ...@@ -31,29 +31,12 @@
#include <tvm/tir/function.h> #include <tvm/tir/function.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/lowered_func.h>
#include <unordered_map> #include <unordered_map>
#include <string> #include <string>
#include "../runtime/meta_data.h" #include "../runtime/meta_data.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
// Extract function information from device function.
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (tir::LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(f->args[i].dtype());
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
}
fmap[f->name] = info;
}
return fmap;
}
inline std::unordered_map<std::string, runtime::FunctionInfo> inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const IRModule& mod) { ExtractFuncInfo(const IRModule& mod) {
......
...@@ -43,50 +43,6 @@ ...@@ -43,50 +43,6 @@
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
// convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations.
Array<tir::Var> args;
Map<tir::Var, PrimExpr> remap_vars;
for (auto var : from->args) {
auto it = from->handle_data_type.find(var);
if (it != from->handle_data_type.end()) {
tir::Var new_var(var->name_hint,
PointerType(PrimType((*it).second->dtype)));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
args.push_back(var);
}
}
tir::PrimFunc func(args, Substitute(from->body, remap_vars));
func = WithAttr(std::move(func), attr::kGlobalSymbol, runtime::String(from->name));
func = WithAttr(std::move(func), tir::attr::kDeviceThreadAxis, from->thread_axis);
if (from->func_type == tir::LoweredFuncType::kDeviceFunc) {
func = WithAttr(std::move(func),
attr::kCallingConv, Integer(CallingConv::kDeviceKernelLaunch));
}
if (from->is_restricted) {
func = WithAttr(std::move(func), tir::attr::kNoAlias, Integer(1));
}
return func;
}
IRModule ToIRModule(const Array<tir::LoweredFunc>& funcs) {
Map<GlobalVar, BaseFunc> functions;
for (size_t i = 0; i < funcs.size(); ++i) {
auto f = funcs[i];
tir::PrimFunc pf = ToPrimFunc(f);
if (i == 0) {
pf = WithAttr(std::move(pf), tir::attr::kIsEntryFunc, Integer(1));
}
functions.Set(GlobalVar(f->name), pf);
}
return IRModule(functions);
}
runtime::Module Build(IRModule mod, const Target& target) { runtime::Module Build(IRModule mod, const Target& target) {
if (BuildConfig::Current()->disable_assert) { if (BuildConfig::Current()->disable_assert) {
mod = tir::transform::SkipAssert()(mod); mod = tir::transform::SkipAssert()(mod);
...@@ -284,9 +240,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod, ...@@ -284,9 +240,6 @@ runtime::Module PackImportsToLLVM(const runtime::Module& mod,
TVM_REGISTER_GLOBAL("target.Build") TVM_REGISTER_GLOBAL("target.Build")
.set_body_typed(Build); .set_body_typed(Build);
TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
.set_body_typed(ToIRModule);
// Export two auxiliary function to the runtime namespace. // Export two auxiliary function to the runtime namespace.
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC") TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
.set_body_typed(PackImportsToC); .set_body_typed(PackImportsToC);
......
...@@ -448,7 +448,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) { ...@@ -448,7 +448,7 @@ CodeGenLLVM::CreateDebugInfo(llvm::Module* module) {
auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>(); auto debug_info = llvm::make_unique<CodeGenLLVM::DebugInfo>();
debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module); debug_info->di_builder_ = llvm::make_unique<llvm::DIBuilder>(*module);
#endif #endif
// TODO(tulloch): pass this information through relay::Span classes to the LoweredFunc instance? // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance?
debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/"); debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/");
debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit( debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit(
llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "", llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "",
......
...@@ -67,20 +67,23 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -67,20 +67,23 @@ class LLVMModuleNode final : public runtime::ModuleNode {
} else if (name == "_get_target_triple") { } else if (name == "_get_target_triple") {
std::string target_triple = tm_->getTargetTriple().str(); std::string target_triple = tm_->getTargetTriple().str();
return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) { return PackedFunc([target_triple](TVMArgs args, TVMRetValue *rv) {
* rv = target_triple; *rv = target_triple;
}); });
} }
if (ee_ == nullptr) LazyInitJIT(); if (ee_ == nullptr) LazyInitJIT();
// This LLVMModule is empty and no function can be retrieved.
if (entry_func_.empty()) return nullptr;
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
const std::string& fname = (name == runtime::symbol::tvm_module_main ?
entry_func_ : name);
TVMBackendPackedCFunc faddr = TVMBackendPackedCFunc faddr;
reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(fname)); if (name == runtime::symbol::tvm_module_main) {
const char* entry_name = reinterpret_cast<const char*>(
GetGlobalAddr(runtime::symbol::tvm_module_main));
CHECK(entry_name != nullptr)
<< "Symbol " << runtime::symbol::tvm_module_main << " is not presented";
faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(entry_name));
} else {
faddr = reinterpret_cast<TVMBackendPackedCFunc>(GetFunctionAddr(name));
}
if (faddr == nullptr) return PackedFunc(); if (faddr == nullptr) return PackedFunc();
return WrapPackedFunc(faddr, sptr_to_self); return WrapPackedFunc(faddr, sptr_to_self);
} }
...@@ -205,6 +208,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -205,6 +208,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get()); std::unique_ptr<CodeGenLLVM> cg = CodeGenLLVM::Create(tm_.get());
std::vector<PrimFunc> funcs; std::vector<PrimFunc> funcs;
std::string entry_func;
for (auto kv : mod->functions) { for (auto kv : mod->functions) {
CHECK(kv.second->IsInstance<PrimFuncNode>()) CHECK(kv.second->IsInstance<PrimFuncNode>())
<< "Can only lower IR Module with PrimFuncs"; << "Can only lower IR Module with PrimFuncs";
...@@ -212,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -212,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol); auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
CHECK(global_symbol.defined()); CHECK(global_symbol.defined());
entry_func_ = global_symbol; entry_func = global_symbol;
} }
funcs.push_back(f); funcs.push_back(f);
} }
...@@ -225,8 +229,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -225,8 +229,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
cg->AddFunction(f); cg->AddFunction(f);
} }
if (entry_func_.length() != 0) { if (entry_func.length() != 0) {
cg->AddMainFunction(entry_func_); cg->AddMainFunction(entry_func);
} }
module_ = cg->Finish(); module_ = cg->Finish();
...@@ -321,13 +325,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -321,13 +325,7 @@ class LLVMModuleNode final : public runtime::ModuleNode {
CHECK(ee_ != nullptr) CHECK(ee_ != nullptr)
<< "Failed to initialize jit engine for " << mptr_->getTargetTriple(); << "Failed to initialize jit engine for " << mptr_->getTargetTriple();
ee_->runStaticConstructorsDestructors(false); ee_->runStaticConstructorsDestructors(false);
// setup context address.
// we will skip context setup if this LLVMModule is empty.
if (GetGlobalAddr(runtime::symbol::tvm_module_main) == 0)
return;
entry_func_ =
reinterpret_cast<const char*>(GetGlobalAddr(runtime::symbol::tvm_module_main));
if (void** ctx_addr = reinterpret_cast<void**>( if (void** ctx_addr = reinterpret_cast<void**>(
GetGlobalAddr(runtime::symbol::tvm_module_ctx))) { GetGlobalAddr(runtime::symbol::tvm_module_ctx))) {
*ctx_addr = this; *ctx_addr = this;
...@@ -356,8 +354,6 @@ class LLVMModuleNode final : public runtime::ModuleNode { ...@@ -356,8 +354,6 @@ class LLVMModuleNode final : public runtime::ModuleNode {
// The target configuration string // The target configuration string
std::string target_; std::string target_;
// Name of entry function.
std::string entry_func_;
// JIT lock // JIT lock
std::mutex mutex_; std::mutex mutex_;
// execution engine // execution engine
......
...@@ -29,7 +29,6 @@ ...@@ -29,7 +29,6 @@
#include <tvm/tir/function.h> #include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
#include <string> #include <string>
#include <vector> #include <vector>
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/lowered_func.h>
#include <vector> #include <vector>
#include <memory> #include <memory>
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/codegen.h> #include <tvm/target/codegen.h>
#include <string> #include <string>
#include <vector> #include <vector>
......
...@@ -22,8 +22,10 @@ ...@@ -22,8 +22,10 @@
* \brief Pass to check if memory accesses are legal. * \brief Pass to check if memory accesses are legal.
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
namespace tvm { namespace tvm {
...@@ -44,7 +46,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { ...@@ -44,7 +46,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
public: public:
/// Special member functions /// Special member functions
//@{ //@{
explicit MemoryAccessVerifier(LoweredFunc f, int device_type) explicit MemoryAccessVerifier(PrimFunc f, int device_type)
: func_(f), dev_type_(device_type) {} : func_(f), dev_type_(device_type) {}
virtual ~MemoryAccessVerifier() = default; virtual ~MemoryAccessVerifier() = default;
MemoryAccessVerifier(const MemoryAccessVerifier &) = delete; MemoryAccessVerifier(const MemoryAccessVerifier &) = delete;
...@@ -116,7 +118,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { ...@@ -116,7 +118,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
CHECK(V) << "Invalid Variable\n"; CHECK(V) << "Invalid Variable\n";
// Variable is from function args. Return true. // Variable is from function args. Return true.
if (V == func_->args[0].get()) return true; if (V == func_->params[0].get()) return true;
// The value is expected to come from a tvm_struct_get Call. // The value is expected to come from a tvm_struct_get Call.
// Get the first argument of tvm_struct_get, and continue. // Get the first argument of tvm_struct_get, and continue.
...@@ -179,18 +181,33 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { ...@@ -179,18 +181,33 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
const ProducerConsumerNode *pc_{nullptr}; const ProducerConsumerNode *pc_{nullptr};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access) bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
//@} //@}
LoweredFunc func_{nullptr}; ///< Function to be verified. tir::PrimFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type int dev_type_{kDLCPU}; ///< Device type
std::unordered_map<const VarNode *, PrimExpr> defs_; ///< Variable definitions std::unordered_map<const VarNode *, PrimExpr> defs_; ///< Variable definitions
}; };
} // namespace } // namespace
/// Interface of VerifyMemory pass /// Interface of VerifyMemory pass
bool VerifyMemory(LoweredFunc func, int device_type) { void VerifyMemory(const IRModule& mod) {
MemoryAccessVerifier v(func, device_type); for (auto kv : mod->functions) {
v.Run(); if (auto* n = kv.second.as<PrimFuncNode>()) {
return !v.Failed(); PrimFunc func = GetRef<PrimFunc>(n);
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerWarpMemory: Require the target attribute";
MemoryAccessVerifier v(func, target->device_type);
v.Run();
if (v.Failed()) {
LOG(FATAL)
<< "ValueError: Direct host side access to device memory is detected."
<< " Did you forget to bind?\n"
<< func;
}
}
}
} }
TVM_REGISTER_GLOBAL("tir.analysis.verify_memory")
.set_body_typed(VerifyMemory);
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
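VerifyMemory is now a module-level analysis: it reads each PrimFunc's
"target" attribute instead of taking a device_type argument, and it fails
with an error instead of returning a bool. A minimal Python sketch,
assuming `mod` is an already-lowered IRModule (the calls mirror the updated
test later in this diff)::

    import tvm

    # Attach a target to every function, then run the analysis. Illegal
    # host-side access to device memory now raises an error.
    binded_mod = tvm.tir.transform.Apply(
        lambda f: f.with_attr("target", tvm.target.create("cuda")))(mod)
    tvm.tir.analysis.verify_memory(binded_mod)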
...@@ -48,7 +48,7 @@ Buffer decl_buffer(Array<PrimExpr> shape, ...@@ -48,7 +48,7 @@ Buffer decl_buffer(Array<PrimExpr> shape,
DataType dtype, DataType dtype,
std::string name) { std::string name) {
return BufferNode::make( return BufferNode::make(
Var(name, DataType::Handle()), Var(name, PointerType(PrimType(dtype))),
dtype, dtype,
shape, shape,
Array<PrimExpr>(), Array<PrimExpr>(),
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lowered_func.cc
*/
#include <tvm/tir/lowered_func.h>
namespace tvm {
namespace tir {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LoweredFuncNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LoweredFuncNode*>(node.get());
p->stream << "LoweredFunc(" << op->name << ", " << op << ")";
});
TVM_REGISTER_NODE_TYPE(LoweredFuncNode);
} // namespace tir
} // namespace tvm
...@@ -105,13 +105,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit") ...@@ -105,13 +105,6 @@ TVM_REGISTER_GLOBAL("ir_pass.PostOrderVisit")
}); });
}); });
TVM_REGISTER_GLOBAL("ir_pass.LowerStorageAccess")
.set_body([](TVMArgs args, TVMRetValue *ret) {
LoweredFunc f = args[0];
auto n = make_object<LoweredFuncNode>(*f.operator->());
n->body = LowerStorageAccessInfo(f->body);
*ret = LoweredFunc(n);
});
// make from two arguments // make from two arguments
#define REGISTER_PASS(PassName) \ #define REGISTER_PASS(PassName) \
...@@ -128,7 +121,6 @@ REGISTER_PASS(VectorizeLoop); ...@@ -128,7 +121,6 @@ REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(SkipVectorize); REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop); REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin); REGISTER_PASS(InjectCopyIntrin);
REGISTER_PASS(MakeAPI);
REGISTER_PASS(StorageRewrite); REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync); REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo); REGISTER_PASS(LowerStorageAccessInfo);
...@@ -138,9 +130,6 @@ REGISTER_PASS(InjectDoubleBuffer); ...@@ -138,9 +130,6 @@ REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition); REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp); REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope); REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode); REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope); REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers); REGISTER_PASS(InstrumentBoundCheckers);
......
...@@ -994,29 +994,6 @@ class VectorAllocRewriter : public StmtExprMutator { ...@@ -994,29 +994,6 @@ class VectorAllocRewriter : public StmtExprMutator {
}; };
LoweredFunc PointerValueTypeRewrite(LoweredFunc f) {
auto n = make_object<LoweredFuncNode>(*f.operator->());
VectorAllocRewriter rewriter;
n->body = rewriter(n->body);
for (Var arg : f->args) {
if (arg.dtype().is_handle()) {
const auto& tvec = rewriter.acc_map_[arg.get()];
if (tvec.size() == 1) {
PrimExpr dtype = make_const(tvec[0], 0);
n->handle_data_type.Set(arg, dtype);
} else {
// always set data type to be non vectorized so
// load/store can still work via scalarization
if (tvec.size() != 0 && !n->handle_data_type.count(arg)) {
PrimExpr dtype = make_const(tvec[0].with_lanes(1), 0);
n->handle_data_type.Set(arg, dtype);
}
}
}
}
return LoweredFunc(n);
}
PrimFunc PointerValueTypeRewrite(PrimFunc f) { PrimFunc PointerValueTypeRewrite(PrimFunc f) {
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
VectorAllocRewriter rewriter; VectorAllocRewriter rewriter;
......
...@@ -22,7 +22,9 @@ ...@@ -22,7 +22,9 @@
*/ */
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/transform.h>
#include <tvm/target/target.h>
#include <tvm/runtime/registry.h>
#include "../../target/datatype/registry.h" #include "../../target/datatype/registry.h"
namespace tvm { namespace tvm {
...@@ -129,11 +131,26 @@ class CustomDatatypesLowerer : public StmtExprMutator { ...@@ -129,11 +131,26 @@ class CustomDatatypesLowerer : public StmtExprMutator {
std::string target_; std::string target_;
}; };
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) {
auto n = make_object<LoweredFuncNode>(*f.operator->()); namespace transform {
n->body = CustomDatatypesLowerer(target)(n->body);
return LoweredFunc(n); Pass LowerCustomDatatypes() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerCustomDatatypes: Require the target attribute";
n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {});
} }
TVM_REGISTER_GLOBAL("tir.transform.LowerCustomDatatypes")
.set_body_typed(LowerCustomDatatypes);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
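LowerCustomDatatypes becomes a PrimFunc pass that reads the target from the
function's "target" attribute instead of taking it as a string argument. A
minimal Python sketch, assuming `sched` and `args` describe a schedule that
uses a custom datatype (mirroring the updated test later in this diff)::

    import tvm

    mod = tvm.lower(sched, args)
    target = tvm.target.create("llvm")
    # Bind the target attribute first, then run the pass over the module.
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
    mod = tvm.tir.transform.LowerCustomDatatypes()(mod)
    module = tvm.build(mod, target="llvm")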
...@@ -18,20 +18,24 @@ ...@@ -18,20 +18,24 @@
*/ */
/*! /*!
* \file make_api.cc Build API function. * \file make_packed_api.cc Lower PrimFunc to use the packed function API.
*/ */
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h> #include <tvm/tir/analysis.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/runtime/device_api.h> #include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <unordered_set> #include <unordered_set>
#include "ir_util.h" #include "../pass/ir_util.h"
#include "arg_binder.h" #include "../pass/arg_binder.h"
namespace tvm { namespace tvm {
namespace tir { namespace tir {
...@@ -40,14 +44,18 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { ...@@ -40,14 +44,18 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0)); return AssertStmtNode::make(lhs == rhs, msg, EvaluateNode::make(0));
} }
LoweredFunc MakeAPI(Stmt body, PrimFunc MakePackedAPI(PrimFunc&& func,
std::string name, int num_unpacked_args) {
Array<ObjectRef> api_args, auto global_symbol = func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
int num_unpacked_args, CHECK(global_symbol.defined())
bool is_restricted) { << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute";
std::string name_hint = global_symbol;
auto* func_ptr = func.CopyOnWrite();
const Stmt nop = EvaluateNode::make(0); const Stmt nop = EvaluateNode::make(0);
int num_args = static_cast<int>(api_args.size()); int num_args = static_cast<int>(func_ptr->params.size());
CHECK_LE(num_unpacked_args, num_args); CHECK_LE(num_unpacked_args, num_args);
int num_packed_args = num_args - num_unpacked_args; int num_packed_args = num_args - num_unpacked_args;
// Data field definitions // Data field definitions
// The packed fields // The packed fields
...@@ -69,9 +77,10 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -69,9 +77,10 @@ LoweredFunc MakeAPI(Stmt body,
// local function definitions // local function definitions
// load i-th argument as type t // load i-th argument as type t
auto f_arg_value = [&](DataType t, int i) { auto f_arg_value = [&](DataType t, int i) {
Array<PrimExpr> call_args{v_packed_args, Array<PrimExpr> call_args{
IntImm(DataType::Int(32), i), v_packed_args,
IntImm(DataType::Int(32), intrinsic::kTVMValueContent)}; IntImm(DataType::Int(32), i),
IntImm(DataType::Int(32), intrinsic::kTVMValueContent)};
// load 64 bit version // load 64 bit version
DataType api_type = APIType(t); DataType api_type = APIType(t);
PrimExpr res = CallNode::make( PrimExpr res = CallNode::make(
...@@ -83,13 +92,7 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -83,13 +92,7 @@ LoweredFunc MakeAPI(Stmt body,
} }
return res; return res;
}; };
// get declaration of argument i
auto f_arg_decl = [&](int i) {
std::ostringstream os;
os << "arg" << i;
const VarNode* v = api_args[i].as<VarNode>();
return Var(os.str(), v ? v->dtype: DataType::Handle());
};
// --------------------------- // ---------------------------
// start of logic // start of logic
// add signature for packed arguments. // add signature for packed arguments.
...@@ -99,16 +102,25 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -99,16 +102,25 @@ LoweredFunc MakeAPI(Stmt body,
args.push_back(v_num_packed_args); args.push_back(v_num_packed_args);
std::ostringstream os; std::ostringstream os;
os << name << ": num_args should be " << num_packed_args; os << name_hint << ": num_args should be " << num_packed_args;
seq_init.emplace_back( seq_init.emplace_back(
MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
} }
// Save the input variables and buffers that will be bound later. // Need to re-declare vars, in case some arguments also appear in the buffer.
std::vector<std::pair<Var, Var> > var_defs; std::vector<std::pair<Var, Var> > var_def;
std::vector<std::pair<Buffer, Var> > buf_defs; std::vector<std::pair<Var, Buffer> > buffer_def;
for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
Var v_arg = f_arg_decl(i); for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
Var param = func_ptr->params[i];
Var v_arg = Var("arg" + std::to_string(i), param->dtype);
auto it = func_ptr->buffer_map.find(param);
if (it != func_ptr->buffer_map.end()) {
buffer_def.emplace_back(v_arg, (*it).second);
} else {
var_def.emplace_back(v_arg, param);
}
if (i < num_packed_args) { if (i < num_packed_args) {
// Value loads // Value loads
seq_init.emplace_back(LetStmtNode::make( seq_init.emplace_back(LetStmtNode::make(
...@@ -123,35 +135,26 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -123,35 +135,26 @@ LoweredFunc MakeAPI(Stmt body,
DataType t = v_arg.dtype(); DataType t = v_arg.dtype();
if (t.is_handle()) { if (t.is_handle()) {
std::ostringstream msg; std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be pointer"; msg << name_hint << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back( seq_check.emplace_back(
AssertStmtNode::make(tcode == kTVMOpaqueHandle || AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
tcode == kTVMNDArrayHandle || tcode == kTVMNDArrayHandle ||
tcode == kTVMDLTensorHandle || tcode == kTVMDLTensorHandle ||
tcode == kTVMNullptr, msg.str(), nop)); tcode == kTVMNullptr, msg.str(), nop));
} else if (t.is_int() || t.is_uint()) { } else if (t.is_int() || t.is_uint()) {
std::ostringstream msg; std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be int"; msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop)); seq_check.emplace_back(AssertStmtNode::make(tcode == kDLInt, msg.str(), nop));
} else { } else {
CHECK(t.is_float()); CHECK(t.is_float());
std::ostringstream msg; std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be float"; msg << name_hint << ": Expect arg[" << i << "] to be float";
seq_check.emplace_back( seq_check.emplace_back(
AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop)); AssertStmtNode::make(tcode == kDLFloat, msg.str(), nop));
} }
} else { } else {
args.push_back(v_arg); args.push_back(v_arg);
} }
// add checks for functions.
if (api_args[i].as<VarNode>()) {
var_defs.emplace_back(std::make_pair(Downcast<Var>(api_args[i]), v_arg));
} else {
// Buffer checks
CHECK(api_args[i].as<BufferNode>())
<< "api_args can only be Buffer or Var";
buf_defs.emplace_back(std::make_pair(Downcast<Buffer>(api_args[i]), v_arg));
}
} }
// allow return value if the function is packed. // allow return value if the function is packed.
...@@ -170,24 +173,22 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -170,24 +173,22 @@ LoweredFunc MakeAPI(Stmt body,
// either 0 or the original stride will be correctly used. Checks here have // either 0 or the original stride will be correctly used. Checks here have
// to use the args that may have no let binding yet. Therefore, hoisting let // to use the args that may have no let binding yet. Therefore, hoisting let
// binding for args before buffer declaration is needed. // binding for args before buffer declaration is needed.
for (const auto& arg : var_defs) { for (const auto& kv : var_def) {
binder.Bind(arg.first, arg.second, arg.second->name_hint, true); binder.Bind(kv.second, kv.first, kv.first->name_hint, true);
}
for (const auto& kv : buffer_def) {
binder.BindDLTensor(kv.second, device_type, device_id,
kv.first, kv.first->name_hint);
} }
for (const auto& buf_arg : buf_defs) { if (num_unpacked_args == 0) {
binder.BindDLTensor(buf_arg.first, device_type, device_id, func = WithAttr(std::move(func), tvm::attr::kCallingConv, Integer(CallingConv::kCPackedFunc));
buf_arg.second, buf_arg.second->name_hint);
} }
ObjectPtr<LoweredFuncNode> n = make_object<LoweredFuncNode>(); auto body = AttrStmtNode::make(
n->name = name;
n->args = args;
n->handle_data_type = binder.def_handle_dtype();
n->is_packed_func = num_unpacked_args == 0;
n->is_restricted = is_restricted;
body = AttrStmtNode::make(
make_zero(DataType::Int(32)), attr::compute_scope, make_zero(DataType::Int(32)), attr::compute_scope,
StringImmNode::make(name + "_compute_"), body); StringImmNode::make(name_hint + "_compute_"), func_ptr->body);
// Set device context // Set device context
if (vmap.count(device_id.get())) { if (vmap.count(device_id.get())) {
PrimExpr node = StringImmNode::make("default"); PrimExpr node = StringImmNode::make("default");
...@@ -203,21 +204,59 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -203,21 +204,59 @@ LoweredFunc MakeAPI(Stmt body,
device_type, device_id}, CallNode::Intrinsic))); device_type, device_id}, CallNode::Intrinsic)));
body = SeqStmt({set_device, body}); body = SeqStmt({set_device, body});
} }
n->body = MergeNest( func_ptr->body = MergeNest(
{seq_init, binder.init_nest(), seq_check, binder.asserts()}, body); {seq_init, binder.init_nest(), seq_check, binder.asserts()}, body);
LoweredFunc f(n); func_ptr->params = args;
Array<Var> undefined = UndefinedVars(f->body, f->args);
Array<Var> undefined = UndefinedVars(func_ptr->body, func_ptr->params);
if (undefined.size() != 0) { if (undefined.size() != 0) {
std::ostringstream os; std::ostringstream os;
for (Var v : undefined) { for (Var v : undefined) {
os << " \'" << v->name_hint << "\' "; os << " \'" << v->name_hint << "\' ";
} }
os << " does not appear in api_args"; os << " is not bound to any variables";
LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str(); LOG(FATAL) << "Not all Vars are passed in api_args: " << os.str();
} }
return f;
func_ptr->buffer_map = Map<Var, Buffer>();
func_ptr->checked_type_ = func_ptr->func_type_annotation();
func_ptr->ret_type = PrimType(DataType::Int(32));
// return the function.
return std::move(func);
} }
namespace transform {
Pass MakePackedAPI(int num_unpacked_args) {
auto pass_func = [num_unpacked_args](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
for (const auto& kv : mptr->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
if (func->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value
== static_cast<int>(CallingConv::kDefault)) {
auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args);
updates.push_back({kv.first, updated_func});
}
}
}
for (const auto& pair : updates) {
mptr->AddUnchecked(pair.first, pair.second);
}
return m;
};
return tvm::transform::CreateModulePass(
pass_func, 0, "tir.MakePackedAPI", {});
}
TVM_REGISTER_GLOBAL("tir.transform.MakePackedAPI")
.set_body_typed(MakePackedAPI);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
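MakeAPI is replaced by the module-level MakePackedAPI pass: instead of
assembling a LoweredFunc from a body, name, and api_args, callers construct
a PrimFunc carrying a "global_symbol" attribute and run the pass over an
IRModule. A minimal Python sketch, assuming `stmt` is a lowered statement
and `Ab` a buffer from tvm.tir.decl_buffer (this mirrors the MakeAPILegacy
test helper added later in this diff)::

    import tvm

    f = tvm.tir.PrimFunc([Ab], stmt).with_attr(
        "global_symbol", tvm.runtime.String("ramp"))
    f = f.with_attr("tir.no_alias", True)  # replaces is_restricted=True
    mod = tvm.IRModule.from_expr(f)
    # Only functions with the default calling convention are rewritten.
    mod = tvm.tir.transform.MakePackedAPI()(mod)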
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
*/ */
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include <unordered_map> #include <unordered_map>
...@@ -74,8 +75,8 @@ class ThreadAxisRewriter : private StmtExprMutator { ...@@ -74,8 +75,8 @@ class ThreadAxisRewriter : private StmtExprMutator {
std::unordered_map<const VarNode*, Var> vmap_; std::unordered_map<const VarNode*, Var> vmap_;
}; };
LoweredFunc
RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) { PrimFunc RemapThreadAxis(PrimFunc&& f, Map<PrimExpr, IterVar> thread_map) {
std::unordered_map<std::string, IterVar> tmap; std::unordered_map<std::string, IterVar> tmap;
for (const auto& kv : thread_map) { for (const auto& kv : thread_map) {
const StringImmNode* str = kv.first.as<StringImmNode>(); const StringImmNode* str = kv.first.as<StringImmNode>();
...@@ -83,18 +84,33 @@ RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) { ...@@ -83,18 +84,33 @@ RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> thread_map) {
tmap[str->value] = kv.second; tmap[str->value] = kv.second;
} }
CHECK_EQ(f->func_type, kDeviceFunc); auto thread_axis = f->GetAttr<Array<IterVar> >(tir::attr::kDeviceThreadAxis);
auto n = make_object<LoweredFuncNode>(*f.operator->()); auto* n = f.CopyOnWrite();
// replace the thread axis // replace the thread axis
for (size_t i = 0; i < n->thread_axis.size(); ++i) { for (size_t i = 0; i < thread_axis.size(); ++i) {
auto it = tmap.find(n->thread_axis[i]->thread_tag); auto it = tmap.find(thread_axis[i]->thread_tag);
if (it != tmap.end()) { if (it != tmap.end()) {
n->thread_axis.Set(i, it->second); thread_axis.Set(i, it->second);
} }
} }
n->body = ThreadAxisRewriter(tmap).Rewrite(n->body); n->body = ThreadAxisRewriter(tmap).Rewrite(std::move(n->body));
return LoweredFunc(n); return WithAttr(std::move(f), tir::attr::kDeviceThreadAxis, thread_axis);
} }
namespace transform {
Pass RemapThreadAxis(Map<PrimExpr, IterVar> thread_map) {
auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) {
return RemapThreadAxis(std::move(f), thread_map);
};
return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {});
}
TVM_REGISTER_GLOBAL("tir.transform.RemapThreadAxis")
.set_body_typed(RemapThreadAxis);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
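RemapThreadAxis follows the same pattern. Only the C++ registration is part
of this change, so a Python-side sketch has to go through the FFI registry;
everything below is an assumption for illustration, and the module `mod` is
assumed to contain device PrimFuncs carrying the tir.device_thread_axis
attribute::

    import tvm
    from tvm import te

    # Hypothetical usage: remap threadIdx.x to threadIdx.y across `mod`.
    remap = tvm.get_global_func("tir.transform.RemapThreadAxis")
    axis_map = {tvm.tir.StringImm("threadIdx.x"): te.thread_axis("threadIdx.y")}
    mod = remap(axis_map)(mod)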
...@@ -264,7 +264,6 @@ class HostDeviceSplitter : public StmtMutator { ...@@ -264,7 +264,6 @@ class HostDeviceSplitter : public StmtMutator {
std::string name_prefix_; std::string name_prefix_;
// Number of device functions. // Number of device functions.
int device_func_counter_{0}; int device_func_counter_{0};
std::vector<LoweredFunc> device_funcs_;
std::unordered_map<const VarNode*, PrimExpr> handle_data_type_; std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
}; };
......
...@@ -117,8 +117,8 @@ TEST(BuildModule, Heterogeneous) { ...@@ -117,8 +117,8 @@ TEST(BuildModule, Heterogeneous) {
std::unordered_map<Tensor, Buffer> binds; std::unordered_map<Tensor, Buffer> binds;
auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config); auto lowered_s1 = lower(s1, args1, "elemwise_add", binds, config);
auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config); auto lowered_s2 = lower(s2, args2, "elemwise_sub", binds, config);
Map<tvm::Target, Array<LoweredFunc>> inputs = {{target_cuda, lowered_s1}, Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1},
{target_llvm, lowered_s2}}; {target_llvm, lowered_s2}};
auto module = build(inputs, Target(), config); auto module = build(inputs, Target(), config);
// Assertion for build. // Assertion for build.
......
...@@ -18,29 +18,6 @@ import tvm ...@@ -18,29 +18,6 @@ import tvm
from tvm import te from tvm import te
import numpy as np import numpy as np
def lower(s, args, name="mydot"):
binds = {}
arg_list = []
for x in args:
assert isinstance(x, te.tensor.Tensor)
buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.op.name)
binds[x] = buf
arg_list.append(buf)
s = s.normalize()
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 16)
stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
stmt = tvm.tir.ir_pass.Simplify(stmt)
fapi = tvm.tir.ir_pass.MakeAPI(stmt, name, arg_list, 0, True)
fapi = tvm.tir.ir_pass.LowerTVMBuiltin(fapi)
return fapi
def mybuild(fapi, target="llvm"):
return
def test_dot(): def test_dot():
nn = 12 nn = 12
......
...@@ -38,8 +38,9 @@ def test_dltensor_compatible(): ...@@ -38,8 +38,9 @@ def test_dltensor_compatible():
with ib.for_range(0, n - 1, "i") as i: with ib.for_range(0, n - 1, "i") as i:
A[i + 1] = A[i] + 1 A[i + 1] = A[i] + 1
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "arange", [Ab], 0, True)
mod = tvm.testing.LoweredFuncsToIRModule([fapi])
mod = tvm.testing.MakeAPILegacy(stmt, "arange", [Ab], 0, True)
mod = tvm.tir.transform.LowerTVMBuiltin()(mod) mod = tvm.tir.transform.LowerTVMBuiltin()(mod)
f = tvm.target.codegen.build_module(mod, "stackvm") f = tvm.target.codegen.build_module(mod, "stackvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
......
...@@ -156,7 +156,7 @@ def test_simplex_data_transferring(): ...@@ -156,7 +156,7 @@ def test_simplex_data_transferring():
elemwise_sub], elemwise_sub],
name="elemwise_sub") name="elemwise_sub")
target_flist = {target_device: [lower_add], target_host: [lower_sub]} target_flist = {target_device: lower_add, target_host: lower_sub}
mhost = tvm.build(target_flist, target_host=target_host) mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx] ctx = [host_ctx, device_ctx]
mod = graph_runtime.create(graph, mhost, ctx) mod = graph_runtime.create(graph, mhost, ctx)
...@@ -354,8 +354,9 @@ def test_duplex_data_transferring(): ...@@ -354,8 +354,9 @@ def test_duplex_data_transferring():
elemwise_sub], elemwise_sub],
name="elemwise_sub") name="elemwise_sub")
target_flist = {target_device: [lower_add0, lower_add1], target_host: [lower_sub]} lower_add0.update(lower_add1)
target_flist = {target_device: lower_add0, target_host: lower_sub}
mhost = tvm.build(target_flist, target_host=target_host) mhost = tvm.build(target_flist, target_host=target_host)
ctx = [host_ctx, device_ctx] ctx = [host_ctx, device_ctx]
params = {} params = {}
......
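With tvm.lower returning an IRModule, heterogeneous builds now hand
tvm.build a {target: IRModule} dict, and several lowered modules for the
same target are merged with update() first. A condensed sketch of the
pattern used in the updated test above, assuming lower_add0, lower_add1,
and lower_sub come from tvm.lower and target_device/target_host are target
strings::

    # Merge the two device modules, then build one dict entry per target.
    lower_add0.update(lower_add1)
    target_flist = {target_device: lower_add0, target_host: lower_sub}
    mhost = tvm.build(target_flist, target_host=target_host)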
...@@ -57,8 +57,8 @@ def test_dso_module_load(): ...@@ -57,8 +57,8 @@ def test_dso_module_load():
tvm.tir.Store(Ab.data, tvm.tir.Store(Ab.data,
tvm.tir.Load(dtype, Ab.data, i) + 1, tvm.tir.Load(dtype, Ab.data, i) + 1,
i + 1)) i + 1))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) m = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
m = tvm.driver.build(fapi, target="llvm") m = tvm.driver.build(m, target="llvm")
for name in names: for name in names:
m.save(name) m.save(name)
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import ctypes import ctypes
import math import math
def test_llvm_intrin(): def test_llvm_intrin():
ib = tvm.tir.ir_builder.create() ib = tvm.tir.ir_builder.create()
n = tvm.runtime.convert(4) n = tvm.runtime.convert(4)
...@@ -34,7 +35,8 @@ def test_llvm_intrin(): ...@@ -34,7 +35,8 @@ def test_llvm_intrin():
tvm.tir.Call( tvm.tir.Call(
"int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0))) "int32", "prefetch", args, tvm.tir.Call.Intrinsic, None, 0)))
body = ib.get() body = ib.get()
func = tvm.tir.ir_pass.MakeAPI(body, "prefetch", [A], 0, True)
func = tvm.testing.MakeAPILegacy(body, "prefetch", [A], 0, True)
fcode = tvm.build(func, None, "llvm") fcode = tvm.build(func, None, "llvm")
...@@ -85,7 +87,7 @@ def test_llvm_lookup_intrin(): ...@@ -85,7 +87,7 @@ def test_llvm_lookup_intrin():
x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A) x = tvm.tir.call_llvm_intrin("uint8x8", "llvm.ctpop.i8", tvm.tir.const(1, 'uint32'), A)
ib.emit(x) ib.emit(x)
body = ib.get() body = ib.get()
func = tvm.tir.ir_pass.MakeAPI(body, "ctpop", [A], 1, True) func = tvm.testing.MakeAPILegacy(body, "ctpop", [A], 1, True)
fcode = tvm.build(func, None, "llvm") fcode = tvm.build(func, None, "llvm")
...@@ -307,8 +309,9 @@ def test_multiple_func(): ...@@ -307,8 +309,9 @@ def test_multiple_func():
f2 = tvm.lower(s, [A, B, C], name="fadd1") f2 = tvm.lower(s, [A, B, C], name="fadd1")
f1 = tvm.lower(s, [A, B, C], name="fadd2") f1 = tvm.lower(s, [A, B, C], name="fadd2")
m = tvm.build([f1, f2], "llvm") m = tvm.build([f1, f2], "llvm")
fadd1 = m['fadd1']
fadd2 = m['fadd2'] fadd2 = m['fadd2']
fadd1 = m['fadd1']
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
# launch the kernel. # launch the kernel.
n = nn n = nn
...@@ -665,6 +668,7 @@ def test_llvm_shuffle(): ...@@ -665,6 +668,7 @@ def test_llvm_shuffle():
tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32')) tvm.testing.assert_allclose(c_.asnumpy(), (a_.asnumpy() * 2).astype('int32'))
if __name__ == "__main__": if __name__ == "__main__":
test_multiple_func()
test_llvm_large_uintimm() test_llvm_large_uintimm()
test_llvm_import() test_llvm_import()
test_alignment() test_alignment()
...@@ -676,7 +680,6 @@ if __name__ == "__main__": ...@@ -676,7 +680,6 @@ if __name__ == "__main__":
test_llvm_vadd_pipeline() test_llvm_vadd_pipeline()
test_llvm_add_pipeline() test_llvm_add_pipeline()
test_llvm_intrin() test_llvm_intrin()
test_multiple_func()
test_llvm_flip_pipeline() test_llvm_flip_pipeline()
test_llvm_madd_pipeline() test_llvm_madd_pipeline()
test_llvm_temp_space() test_llvm_temp_space()
......
...@@ -19,6 +19,18 @@ from tvm import te ...@@ -19,6 +19,18 @@ from tvm import te
import ctypes import ctypes
import numpy as np import numpy as np
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
def test_static_callback(): def test_static_callback():
dtype = 'int64' dtype = 'int64'
n = te.size_var('n') n = te.size_var('n')
...@@ -32,7 +44,7 @@ def test_static_callback(): ...@@ -32,7 +44,7 @@ def test_static_callback():
with ib.for_range(0, n, "i", for_type="parallel") as i: with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1 A[i] = A[i] + 1
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm") f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
...@@ -55,7 +67,7 @@ def test_static_init(): ...@@ -55,7 +67,7 @@ def test_static_init():
return sh return sh
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
f = tvm.driver.build(fapi, target="llvm") f = tvm.driver.build(fapi, target="llvm")
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
......
...@@ -26,6 +26,18 @@ def run_jit(fapi, check): ...@@ -26,6 +26,18 @@ def run_jit(fapi, check):
s = f.get_source() s = f.get_source()
check(f) check(f)
def MakeAPILegacy(stmt, name, args, num_unpacked_args, noalias):
"""Legacy adapter to create a API"""
f = tvm.tir.PrimFunc(args, stmt).with_attr(
"global_symbol", tvm.runtime.String(name))
f = f.with_attr("tir.is_entry_func", True)
if noalias:
f = f.with_attr("tir.no_alias", True)
mod = tvm.IRModule.from_expr(f)
return tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)
def test_stack_vm_basic(): def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32')) a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func @tvm.register_func
...@@ -36,7 +48,7 @@ def test_stack_vm_basic(): ...@@ -36,7 +48,7 @@ def test_stack_vm_basic():
n = te.size_var('n') n = te.size_var('n')
Ab = tvm.tir.decl_buffer((n, ), "float32") Ab = tvm.tir.decl_buffer((n, ), "float32")
stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0])) stmt = tvm.tir.Evaluate(tvm.tir.call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0, True) fapi = tvm.testing.MakeAPILegacy(stmt, "print_shape", [Ab], 0, True)
run_jit(fapi, lambda f: f(a)) run_jit(fapi, lambda f: f(a))
...@@ -57,7 +69,7 @@ def test_stack_vm_loop(): ...@@ -57,7 +69,7 @@ def test_stack_vm_loop():
ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i)) ib.emit(tvm.tir.call_packed("tvm_stack_vm_print", i))
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f): def check(f):
f(a) f(a)
...@@ -79,7 +91,7 @@ def test_stack_vm_cond(): ...@@ -79,7 +91,7 @@ def test_stack_vm_cond():
A[i + 1] = A[i] + 2 A[i + 1] = A[i] + 2
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab], 0, True) fapi = tvm.testing.MakeAPILegacy(stmt, "test", [Ab], 0, True)
def check(f): def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
...@@ -98,7 +110,7 @@ def test_vm_parallel(): ...@@ -98,7 +110,7 @@ def test_vm_parallel():
with ib.for_range(0, n, "i", for_type="parallel") as i: with ib.for_range(0, n, "i", for_type="parallel") as i:
A[i] = A[i] + 1 A[i] = A[i] + 1
stmt = ib.get() stmt = ib.get()
fapi = tvm.tir.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0, True) fapi = tvm.testing.MakeAPILegacy(stmt, "ramp", [Ab], 0, True)
def check(f): def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a) f(a)
......
...@@ -19,7 +19,6 @@ import tvm ...@@ -19,7 +19,6 @@ import tvm
from tvm import te from tvm import te
from ctypes import * from ctypes import *
import topi import topi
import tvm.tir.ir_pass as ir_pass
import numpy as np import numpy as np
tgt = "llvm" tgt = "llvm"
...@@ -51,10 +50,12 @@ def lower_datatypes_and_build(schedule, args): ...@@ -51,10 +50,12 @@ def lower_datatypes_and_build(schedule, args):
Once datatype lowering is integrated directly into TVM's lower/build Once datatype lowering is integrated directly into TVM's lower/build
process, we won't need to do this manually. process, we won't need to do this manually.
TODO(gus) integrate datatype lowering into build process; change this test""" TODO(gus) integrate datatype lowering into build process; change this test"""
flist = tvm.lower(schedule, args) mod = tvm.lower(schedule, args)
flist = [flist] target = tvm.target.create(tgt)
flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist] mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
return tvm.build(flist[0], target=tgt) mod = tvm.tir.transform.LowerCustomDatatypes()(mod)
return tvm.build(mod, target=tgt)
def test_bfloat_add_and_cast_1(): def test_bfloat_add_and_cast_1():
X = te.placeholder((3, ), name="X") X = te.placeholder((3, ), name="X")
......
...@@ -15,12 +15,13 @@ ...@@ -15,12 +15,13 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
import pytest
from tvm import te from tvm import te
# The following DLDeviceType/TVMDeviceExtType values # The following DLDeviceType/TVMDeviceExtType values
# are originally defined in dlpack.h and c_runtime_api.h. # are originally defined in dlpack.h and c_runtime_api.h.
gpu_devices = [2, 4, 7, 8, 10, 11] gpu_devices = ["cuda", "opencl", "metal", "vulkan"]
other_devices = [1, 3, 9, 12] other_devices = ["llvm", "ext_dev"]
def lower(sch, args): def lower(sch, args):
...@@ -39,8 +40,11 @@ def lower(sch, args): ...@@ -39,8 +40,11 @@ def lower(sch, args):
stmt = tvm.te.schedule.ScheduleOps(sch, bounds) stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
func = tvm.tir.ir_pass.MakeAPI(stmt, "myadd", arg_list, 0, True)
return func f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
"global_symbol", tvm.runtime.String("test"))
mod = tvm.IRModule({"test": f})
return tvm.tir.transform.MakePackedAPI()(mod)
# All computations are bound. # All computations are bound.
...@@ -57,10 +61,13 @@ def test_verify_memory_all_bind(): ...@@ -57,10 +61,13 @@ def test_verify_memory_all_bind():
s[B].bind(bx, te.thread_axis("blockIdx.x")) s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x")) s[B].bind(tx, te.thread_axis("threadIdx.x"))
func = lower(s, [A, B]) mod = lower(s, [A, B])
for dev_type in gpu_devices + other_devices: for dev_type in gpu_devices + other_devices:
assert tvm.tir.ir_pass.VerifyMemory(func, dev_type) binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
tvm.tir.analysis.verify_memory(binded_mod)
# Computations are not bound. # Computations are not bound.
...@@ -74,12 +81,18 @@ def test_verify_memory_not_bind(): ...@@ -74,12 +81,18 @@ def test_verify_memory_not_bind():
# B is not bound to threads. # B is not bound to threads.
s = te.create_schedule(B.op) s = te.create_schedule(B.op)
func = lower(s, [A, B]) mod = lower(s, [A, B])
for dev_type in gpu_devices: for dev_type in gpu_devices:
assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type) binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
with pytest.raises(ValueError):
tvm.tir.analysis.verify_memory(binded_mod)
for dev_type in other_devices: for dev_type in other_devices:
assert tvm.tir.ir_pass.VerifyMemory(func, dev_type) binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
tvm.tir.analysis.verify_memory(binded_mod)
# Computations are partially bound. # Computations are partially bound.
...@@ -98,16 +111,22 @@ def test_verify_memory_partially_bind(): ...@@ -98,16 +111,22 @@ def test_verify_memory_partially_bind():
s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x"))
func = lower(s, [A, B, C, D]) mod = lower(s, [A, B, C, D])
for dev_type in gpu_devices: for dev_type in gpu_devices:
assert not tvm.tir.ir_pass.VerifyMemory(func, dev_type) binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
with pytest.raises(ValueError):
tvm.tir.analysis.verify_memory(binded_mod)
for dev_type in other_devices: for dev_type in other_devices:
assert tvm.tir.ir_pass.VerifyMemory(func, dev_type) binded_mod = tvm.tir.transform.Apply(
lambda f: f.with_attr("target", tvm.target.create(dev_type)))(mod)
tvm.tir.analysis.verify_memory(binded_mod)
if __name__ == "__main__": if __name__ == "__main__":
test_verify_memory_all_bind() test_verify_memory_all_bind()
test_verify_memory_not_bind() test_verify_memory_not_bind()
test_verify_memory_partially_bind() test_verify_memory_partially_bind()
...@@ -118,7 +118,6 @@ def test_in_bounds_vectorize_llvm(): ...@@ -118,7 +118,6 @@ def test_in_bounds_vectorize_llvm():
s[B].vectorize(xi) s[B].vectorize(xi)
# build and invoke the kernel. # build and invoke the kernel.
lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, C], "llvm", simple_mode=False)
print (lowered_func.body)
f = tvm.build(s, [A, C], "llvm") f = tvm.build(s, [A, C], "llvm")
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
# launch the kernel. # launch the kernel.
...@@ -137,7 +136,6 @@ def test_in_bounds_loop_partition_basic_llvm(): ...@@ -137,7 +136,6 @@ def test_in_bounds_loop_partition_basic_llvm():
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4) xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -156,7 +154,6 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b): ...@@ -156,7 +154,6 @@ def test_out_of_bounds_loop_partition_basic_llvm(index_a, index_b):
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4) xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -205,12 +202,11 @@ def test_in_bounds_const_loop_partition_ir(): ...@@ -205,12 +202,11 @@ def test_in_bounds_const_loop_partition_ir():
# after instrumentation # after instrumentation
assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3) assert_bound_instrumentation(stmt, check_attr_stmt, 2 * 3)
assert_bound_instrumentation(stmt, check_branch_stmt, 2) assert_bound_instrumentation(stmt, check_branch_stmt, 2)
print (stmt)
branch_collector = list() branch_collector = list()
collect_visit(stmt, collect_branch_stmt) collect_visit(stmt, collect_branch_stmt)
assert(len(branch_collector) == 2) assert(len(branch_collector) == 2)
print (branch_collector[0].condition)
print (branch_collector[1].condition)
def test_in_bounds_const_loop_partition_llvm(): def test_in_bounds_const_loop_partition_llvm():
with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True): with tvm.target.build_config(instrument_bound_checkers=True, partition_const_loop=True):
...@@ -222,7 +218,6 @@ def test_in_bounds_const_loop_partition_llvm(): ...@@ -222,7 +218,6 @@ def test_in_bounds_const_loop_partition_llvm():
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4) xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -242,7 +237,6 @@ def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b): ...@@ -242,7 +237,6 @@ def test_out_of_bounds_const_loop_partition_llvm(index_a, index_b):
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
xo, xi = s[T].split(T.op.axis[0], factor=4) xo, xi = s[T].split(T.op.axis[0], factor=4)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -276,7 +270,6 @@ def test_in_bounds_conv_llvm(loop_tiling=False): ...@@ -276,7 +270,6 @@ def test_in_bounds_conv_llvm(loop_tiling=False):
if loop_tiling: if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True) lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
print (lowered_func.body)
ctx = tvm.cpu (0) ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm") f = tvm.build(s, [data, kernel, conv], "llvm")
...@@ -320,7 +313,6 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False ...@@ -320,7 +313,6 @@ def test_out_of_bounds_conv_llvm(data_offsets, kernel_offsets, loop_tiling=False
if loop_tiling: if loop_tiling:
oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16) oho, owo, ohi, owi = s[conv].tile(oh, ow, 16, 16)
lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True) lowered_func = tvm.lower(s, [data, kernel, conv], simple_mode=True)
print (lowered_func.body)
ctx = tvm.cpu (0) ctx = tvm.cpu (0)
f = tvm.build(s, [data, kernel, conv], "llvm") f = tvm.build(s, [data, kernel, conv], "llvm")
...@@ -341,7 +333,6 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm(): ...@@ -341,7 +333,6 @@ def test_in_bounds_tensors_with_same_shapes1D_llvm():
T = te.compute((m, ), lambda i: A[i]*B[i]) T = te.compute((m, ), lambda i: A[i]*B[i])
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -361,7 +352,6 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape ...@@ -361,7 +352,6 @@ def test_out_of_bounds_tensors_with_diff_shapes1D_llvm(a_shape, b_shape, c_shape
T = te.compute((m, ), lambda i: A[i]*B[i]) T = te.compute((m, ), lambda i: A[i]*B[i])
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -380,7 +370,6 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm(): ...@@ -380,7 +370,6 @@ def test_in_bounds_tensors_with_same_shapes2D_llvm():
T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j]) T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -400,7 +389,6 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape ...@@ -400,7 +389,6 @@ def test_out_of_bounds_tensors_with_diff_shapes2D_llvm(a_shape, b_shape, c_shape
T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j]) T = te.compute((m, m), lambda i, j: A[i][j]*B[i][j])
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -419,7 +407,7 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm(): ...@@ -419,7 +407,7 @@ def test_in_bounds_tensors_with_same_shapes3D_llvm():
T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p]) T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -439,7 +427,7 @@ def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape ...@@ -439,7 +427,7 @@ def test_out_of_bounds_tensors_with_diff_shapes3D_llvm(a_shape, b_shape, c_shape
T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p]) T = te.compute((m, m, m), lambda i, j, p: A[i][j][p]*B[i][j][p])
s = te.create_schedule(T.op) s = te.create_schedule(T.op)
lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False) lowered_func = tvm.lower (s, [A, B, T], "llvm", simple_mode=False)
print (lowered_func.body)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
f = tvm.build(s, [A, B, T], "llvm") f = tvm.build(s, [A, B, T], "llvm")
...@@ -460,7 +448,7 @@ def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm(): ...@@ -460,7 +448,7 @@ def test_out_of_bounds_tensors_with_zero_shape_op_with_not_zero_shape_llvm():
D = te.compute((), lambda : C + 1) D = te.compute((), lambda : C + 1)
s = te.create_schedule(D.op) s = te.create_schedule(D.op)
stmt = tvm.lower (s, [A, scale, D], simple_mode=True) stmt = tvm.lower (s, [A, scale, D], simple_mode=True)
print (stmt)
# build and invoke the kernel. # build and invoke the kernel.
f = tvm.build(s, [A, scale, D], "llvm") f = tvm.build(s, [A, scale, D], "llvm")
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
......
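A note on the deletions above: the ``print (lowered_func.body)`` statements go away because ``tvm.lower`` now returns an IRModule of PrimFuncs rather than a single LoweredFunc with a ``.body``. A minimal sketch of the new access pattern, using a hypothetical one-op schedule in place of the bound-checker workloads:

::

    import tvm
    from tvm import te

    # Hypothetical stand-in schedule; any te schedule works the same way.
    n = te.size_var("n")
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # tvm.lower now yields an IRModule keyed by global symbol, so the
    # function body lives at mod[name].body, not lowered_func.body.
    mod = tvm.lower(s, [A, B], name="main")
    print(mod["main"].body)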
...@@ -40,8 +40,7 @@ def test_double_buffer(): ...@@ -40,8 +40,7 @@ def test_double_buffer():
stmt = tvm.tir.ir_pass.Simplify(stmt) stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate) assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
mod = tvm.testing.LoweredFuncsToIRModule([f])
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0] count = [0]
......
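``tvm.testing.MakeAPILegacy`` collapses the old two-step dance, ``MakeAPI`` followed by ``LoweredFuncsToIRModule``, into one helper that returns the IRModule directly. A minimal sketch of the helper in isolation, with a trivial ir_builder statement standing in for the double-buffer body used by the test:

::

    import tvm
    from tvm import te

    # Trivial statement built with ir_builder; the real test feeds in
    # the flattened double-buffer body instead.
    ib = tvm.tir.ir_builder.create()
    n = te.size_var("n")
    with ib.for_range(0, n, name="i"):
        ib.emit(tvm.tir.Evaluate(0))
    stmt = ib.get()

    # One call replaces MakeAPI + LoweredFuncsToIRModule:
    mod = tvm.testing.MakeAPILegacy(stmt, "func", [n], 0, True)
    print(mod["func"])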
...@@ -381,7 +381,7 @@ def test_multilevel_splitting_with_indivisble_factors(): ...@@ -381,7 +381,7 @@ def test_multilevel_splitting_with_indivisble_factors():
## But this does the right thing. ## But this does the right thing.
with tvm.target.build_config(partition_const_loop=True): with tvm.target.build_config(partition_const_loop=True):
lowered_body = tvm.lower(s, [A, B]).body lowered_body = tvm.lower(s, [A, B], name="x")["x"].body
def visit_stmt(op): def visit_stmt(op):
return(isinstance(op, tvm.tir.Max)) return(isinstance(op, tvm.tir.Max))
num_max = collect_visit(lowered_body, visit_stmt) num_max = collect_visit(lowered_body, visit_stmt)
...@@ -407,7 +407,7 @@ def test_double_splitting_with_indivisible_factors(): ...@@ -407,7 +407,7 @@ def test_double_splitting_with_indivisible_factors():
# Find the beginning of the Halide IR corresponding to kernel code # Find the beginning of the Halide IR corresponding to kernel code
# and make sure it doesn't have an if statements left # and make sure it doesn't have an if statements left
top_produce = find_top_produce(f.body) top_produce = find_top_produce(f["fadd1"].body)
assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse)))) assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse))))
# check functional correctness of generated code # check functional correctness of generated code
......
...@@ -92,9 +92,7 @@ def test_flatten_double_buffer(): ...@@ -92,9 +92,7 @@ def test_flatten_double_buffer():
stmt = tvm.tir.ir_pass.Simplify(stmt) stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate) assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2 assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True) mod = tvm.testing.MakeAPILegacy(stmt, "db", [A.asobject(), C.asobject()], 2, True)
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2, True)
mod = tvm.testing.LoweredFuncsToIRModule([f])
f = tvm.tir.transform.ThreadSync("shared")(mod)["db"] f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
count = [0] count = [0]
......
...@@ -36,12 +36,7 @@ def test_for(): ...@@ -36,12 +36,7 @@ def test_for():
ib.emit(tvm.tir.call_extern ib.emit(tvm.tir.call_extern
("int32", "fadd", device_context(0), A)) ("int32", "fadd", device_context(0), A))
body = ib.get() body = ib.get()
f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True) mod = tvm.testing.MakeAPILegacy(body, "func", [dev_type, n], 2, True)
# temp adapter to convert loweredFunc to IRModule
# to test passes in the new style.
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.CombineContextCall()(mod) mod = tvm.tir.transform.CombineContextCall()(mod)
assert mod["func"].body.value.dtype == "handle" assert mod["func"].body.value.dtype == "handle"
......
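The same hunk illustrates the new calling convention: a pass such as ``CombineContextCall`` is instantiated as an object and applied to an IRModule, giving back a transformed IRModule, so passes chain by successive application. A runnable sketch with a trivial module (the empty body is a stand-in; only the module-to-module signature matters here):

::

    import tvm

    # Trivial PrimFunc wrapped into a module; IRModule.from_expr keys it "main".
    stmt = tvm.tir.Evaluate(0)
    func = tvm.tir.PrimFunc([], stmt).with_attr(
        "global_symbol", tvm.runtime.String("main"))
    mod = tvm.IRModule.from_expr(func)

    # Construct the pass object, then call it on the module; further passes
    # chain the same way, each returning a new IRModule.
    mod = tvm.tir.transform.CombineContextCall()(mod)
    print(mod["main"])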
...@@ -35,10 +35,8 @@ def test_lower_warp_mem(): ...@@ -35,10 +35,8 @@ def test_lower_warp_mem():
cuda_target = tvm.target.create("cuda") cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32 assert cuda_target.thread_warp_size == 32
f = tvm.lower(s, [A, B], name="f") mod = tvm.lower(s, [A, B], name="f")
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"] fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
mod = tvm.IRModule.from_expr(fdevice) mod = tvm.IRModule.from_expr(fdevice)
......
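Two reusable pieces appear in this hunk: ``tvm.tir.transform.Apply`` runs a PrimFunc-to-PrimFunc callback over every function in the module (here used to attach the ``target`` attribute), and ``SplitHostDevice`` then extracts the device code under the derived name ``f_kernel0``. A sketch of the same staging with a hypothetical one-op GPU kernel; whether a kernel is actually split out depends on the schedule having a thread binding:

::

    import tvm
    from tvm import te

    # Hypothetical kernel with a thread binding so there is a device
    # portion for SplitHostDevice to extract.
    n = 128
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    s[B].bind(B.op.axis[0], te.thread_axis("threadIdx.x"))

    mod = tvm.lower(s, [A, B], name="f")
    cuda_target = tvm.target.create("cuda")
    # Attach the target to every PrimFunc in the module.
    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
    # The device function is keyed by the host name plus a kernel suffix.
    fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
    print(fdevice)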
...@@ -35,11 +35,11 @@ def test_makeapi(): ...@@ -35,11 +35,11 @@ def test_makeapi():
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
num_unpacked_args = 2 num_unpacked_args = 2
f = tvm.tir.ir_pass.MakeAPI( f = tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr(
stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args, True) "tir.no_alias", True).with_attr("global_symbol", tvm.runtime.String("myadd"))
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype) mod = tvm.IRModule.from_expr(f)
assert(len(f.args) == 7) f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
output_ssa = False assert(len(f.params) == 7)
if __name__ == "__main__": if __name__ == "__main__":
......
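The rewritten ``test_makeapi`` makes explicit what ``MakeAPI`` used to do internally: construct a ``tir.PrimFunc``, attach the ``tir.no_alias`` and ``global_symbol`` attributes, and run ``MakePackedAPI`` as an ordinary module pass (the result is keyed ``"main"`` by ``IRModule.from_expr``). A condensed runnable sketch, with a trivial body standing in for the flattened compute:

::

    import tvm
    from tvm import te

    n = te.size_var("n")
    Ab = tvm.tir.decl_buffer((n,), "float32", name="A")
    stmt = tvm.tir.Evaluate(0)  # stand-in for the flattened compute body

    # Attributes that the old MakeAPI attached implicitly:
    func = tvm.tir.PrimFunc([n, Ab], stmt).with_attr(
        "tir.no_alias", True).with_attr(
        "global_symbol", tvm.runtime.String("myadd"))
    mod = tvm.IRModule.from_expr(func)

    # MakePackedAPI(num_unpacked_args) now runs as a regular pass.
    func = tvm.tir.transform.MakePackedAPI(0)(mod)["main"]
    print(func.params)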
...@@ -37,10 +37,9 @@ def test_thread_storage_sync(): ...@@ -37,10 +37,9 @@ def test_thread_storage_sync():
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A') Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2') A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64) stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
cuda_target = tvm.target.create("cuda")
mod = tvm.testing.LoweredFuncsToIRModule([f]) cuda_target = tvm.target.create("cuda")
mod = tvm.testing.MakeAPILegacy(stmt, "test", [Ab, A2b], 0, True)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod) mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", cuda_target))(mod)
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"] fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice) mod = tvm.IRModule.from_expr(fdevice)
......
...@@ -36,7 +36,7 @@ Before reading this tutorial, we assume readers have already known these topics ...@@ -36,7 +36,7 @@ Before reading this tutorial, we assume readers have already known these topics
- Visitor design pattern. Otherwise, check the - Visitor design pattern. Otherwise, check the
`Python AST module <https://docs.python.org/3/library/ast.html>`_ to see how an AST `Python AST module <https://docs.python.org/3/library/ast.html>`_ to see how an AST
visitor is implemented. visitor is implemented.
- How a HalideIR/Schedule is lowered to either a LoweredFunc class or a LLVM module. Otherwise, - How a Schedule is lowered to either an IRModule class or a LLVM module. Otherwise,
take a look at ``python/tvm/build_module.py`` to get some basics. take a look at ``python/tvm/build_module.py`` to get some basics.
""" """
......