Unverified Commit 0a0e58bf by Tianqi Chen Committed by GitHub

[REFACTOR][TIR] Introduce PrimFuncPass. (#5139)

* [REFACTOR][TIR] Introduce PrimFuncPass.

- Introduce PrimFuncPass
- Convert one pass to the unified Pass API.

* Address comments

* Fix comments
parent bbbfc1b0
...@@ -22,3 +22,12 @@ tvm.tir ...@@ -22,3 +22,12 @@ tvm.tir
:imported-members: :imported-members:
:exclude-members: PrimExpr, const :exclude-members: PrimExpr, const
:autosummary: :autosummary:
tvm.tir.transform
-----------------
.. automodule:: tvm.tir.transform
:members:
:imported-members:
:autosummary:
...@@ -150,7 +150,7 @@ class RelayExprNode : public BaseExprNode { ...@@ -150,7 +150,7 @@ class RelayExprNode : public BaseExprNode {
/*! /*!
* \return The checked_type * \return The checked_type
*/ */
const Type& checked_type() const; inline const Type& checked_type() const;
/*! /*!
* \brief Check if the inferred(checked) type of the Expr * \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it. * is backed by a TTypeNode and return it.
......
...@@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT; virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) { virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; // unreachable, written to stop compiler warning throw; // unreachable, written to stop compiler warning
...@@ -115,6 +116,7 @@ class TypeFunctor<R(const Type& n, Args...)> { ...@@ -115,6 +116,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode); TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode); TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode); TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
return vtable; return vtable;
} }
}; };
...@@ -138,6 +140,7 @@ class TVM_DLL TypeVisitor : ...@@ -138,6 +140,7 @@ class TVM_DLL TypeVisitor :
void VisitType_(const TypeCallNode* op) override; void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override; void VisitType_(const TypeDataNode* op) override;
void VisitType_(const PrimTypeNode* op) override; void VisitType_(const PrimTypeNode* op) override;
void VisitType_(const PointerTypeNode* op) override;
}; };
/*! /*!
...@@ -158,6 +161,7 @@ class TVM_DLL TypeMutator : ...@@ -158,6 +161,7 @@ class TVM_DLL TypeMutator :
Type VisitType_(const TypeCallNode* op) override; Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override; Type VisitType_(const TypeDataNode* op) override;
Type VisitType_(const PrimTypeNode* op) override; Type VisitType_(const PrimTypeNode* op) override;
Type VisitType_(const PointerTypeNode* op) override;
private: private:
Array<Type> MutateArray(Array<Type> arr); Array<Type> MutateArray(Array<Type> arr);
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/transform.h
* \brief TIR specific transformation passes.
*/
#ifndef TVM_TIR_TRANSFORM_H_
#define TVM_TIR_TRANSFORM_H_
#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <string>
namespace tvm {
namespace tir {
namespace transform {
using tvm::transform::Pass;
using tvm::transform::PassNode;
using tvm::transform::PassInfo;
using tvm::transform::PassInfoNode;
using tvm::transform::PassContext;
using tvm::transform::PassContextNode;
using tvm::transform::Sequential;
/*
* \brief Create a function pass that optimizes PrimFuncs.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);
/*!
* \brief Create PrimFuncPass to combine context calls in the host function.
*
* \return The pass.
*/
Pass CombineContextCall();
} // namespace transform
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_TRANSFORM_H_
...@@ -45,3 +45,4 @@ from .op import comm_reducer, min, max, sum ...@@ -45,3 +45,4 @@ from .op import comm_reducer, min, max, sum
from . import ir_builder from . import ir_builder
from . import ir_pass from . import ir_pass
from . import transform
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of all TIR transformations"""
# pylint: disable=wildcard-import, invalid-name
from .function_pass import prim_func_pass, PrimFuncPass
from .transform import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir.transform"""
import tvm._ffi
tvm._ffi._init_api("tir.transform", __name__)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TIR specific function pass support."""
import inspect
import functools
import tvm._ffi
from tvm.ir.transform import Pass, PassInfo
from . import _ffi_api
@tvm._ffi.register_object("tir.PrimFuncPass")
class PrimFuncPass(Pass):
"""A pass that works on each :py:func:`tvm.tir.PrimFunc` in a module. A function
pass class should be created through py:func:`tvm.tir.transform.function_pass`.
"""
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyFunctionPass(PrimFuncPass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(func, mod, ctx):
return inst.transform_function(func, mod, ctx)
self.__init_handle_by_constructor__(
_ffi_api.CreatePrimFuncPass, _pass_func, pass_info)
self._inst = inst
def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)
functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
PyFunctionPass.__name__ = pass_cls.__name__
PyFunctionPass.__doc__ = pass_cls.__doc__
PyFunctionPass.__module__ = pass_cls.__module__
return PyFunctionPass
def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Decorate a function pass.
This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(PrimFunc, IRModule, PassContext) -> PrimFunc]]
The transformation function or class.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the function pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
A decorator will be returned if pass_func is not provided,
otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new FunctionPass will be returned when we decorate a pass function.
A new FunctionPass class will be returned when we decorate a class type.
Examples
--------
The following code block decorates a function pass class.
.. code-block:: python
@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
# just for demo purposes
# transform func to new_func
return self.new_func
The following code creates a function pass by decorating
a user defined transform function.
.. code-block:: python
@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
# my transformations here.
return func
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""
if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")
def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.MakeFunctionPass(pass_arg, info)
if pass_func:
return create_function_pass(pass_func)
return create_function_pass
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping existing transformations."""
# pylint: disable=invalid-name
from . import _ffi_api
def CombineContextCall():
"""Combine context calls in the host function.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.CombineContextCall()
...@@ -170,13 +170,13 @@ void IRModuleNode::Add(const GlobalVar& var, ...@@ -170,13 +170,13 @@ void IRModuleNode::Add(const GlobalVar& var,
GetRef<relay::Function>(ptr)); GetRef<relay::Function>(ptr));
} }
auto type = checked_func->checked_type(); Type type = checked_func->checked_type();
CHECK(type.as<relay::IncompleteTypeNode>() == nullptr); CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);
if (functions.find(var) != functions.end()) { if (functions.find(var) != functions.end()) {
CHECK(update) CHECK(update)
<< "Already have definition for " << var->name_hint; << "Already have definition for " << var->name_hint;
auto old_type = functions[var].as<relay::FunctionNode>()->checked_type(); auto old_type = functions[var]->checked_type();
CHECK(relay::AlphaEqual(type, old_type)) CHECK(relay::AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode."; << "Module#update changes type, not possible in this mode.";
} }
......
...@@ -93,6 +93,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) { ...@@ -93,6 +93,10 @@ void TypeVisitor::VisitType_(const TypeDataNode* op) {
void TypeVisitor::VisitType_(const PrimTypeNode* op) { void TypeVisitor::VisitType_(const PrimTypeNode* op) {
} }
void TypeVisitor::VisitType_(const PointerTypeNode* op) {
this->VisitType(op->element_type);
}
Type TypeMutator::VisitType(const Type& t) { Type TypeMutator::VisitType(const Type& t) {
return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t; return t.defined() ? TypeFunctor<Type(const Type&)>::VisitType(t) : t;
} }
...@@ -209,6 +213,16 @@ Type TypeMutator::VisitType_(const PrimTypeNode* op) { ...@@ -209,6 +213,16 @@ Type TypeMutator::VisitType_(const PrimTypeNode* op) {
return GetRef<Type>(op); return GetRef<Type>(op);
} }
Type TypeMutator::VisitType_(const PointerTypeNode* op) {
Type element_type = VisitType(op->element_type);
if (element_type.same_as(op->element_type)) {
return GetRef<Type>(op);
} else {
return PointerType(element_type);
}
}
// Implements bind. // Implements bind.
class TypeBinder : public TypeMutator { class TypeBinder : public TypeMutator {
public: public:
......
...@@ -202,6 +202,22 @@ class AlphaEqualHandler: ...@@ -202,6 +202,22 @@ class AlphaEqualHandler:
return LeafObjectEqual(GetRef<ObjectRef>(lhs), other); return LeafObjectEqual(GetRef<ObjectRef>(lhs), other);
} }
bool VisitType_(const PrimTypeNode* lhs, const Type& other) final {
if (const PrimTypeNode* rhs = other.as<PrimTypeNode>()) {
return lhs->dtype == rhs->dtype;
} else {
return false;
}
}
bool VisitType_(const PointerTypeNode* lhs, const Type& other) final {
if (const PointerTypeNode* rhs = other.as<PointerTypeNode>()) {
return TypeEqual(lhs->element_type, rhs->element_type);
} else {
return false;
}
}
bool VisitType_(const TypeVarNode* lhs, const Type& other) final { bool VisitType_(const TypeVarNode* lhs, const Type& other) final {
if (const TypeVarNode* rhs = other.as<TypeVarNode>()) { if (const TypeVarNode* rhs = other.as<TypeVarNode>()) {
if (lhs->kind != rhs->kind) return false; if (lhs->kind != rhs->kind) return false;
......
...@@ -310,6 +310,9 @@ TVM_REGISTER_GLOBAL("target.Build") ...@@ -310,6 +310,9 @@ TVM_REGISTER_GLOBAL("target.Build")
} }
}); });
TVM_REGISTER_GLOBAL("testing.LoweredFuncsToIRModule")
.set_body_typed(ToIRModule);
// Export two auxiliary function to the runtime namespace. // Export two auxiliary function to the runtime namespace.
TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC") TVM_REGISTER_GLOBAL("runtime.ModulePackImportsToC")
.set_body_typed(PackImportsToC); .set_body_typed(PackImportsToC);
......
...@@ -23,10 +23,12 @@ ...@@ -23,10 +23,12 @@
*/ */
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/tir/function.h> #include <tvm/tir/function.h>
#include <tvm/tir/op.h>
namespace tvm { namespace tvm {
namespace tir { namespace tir {
// Get the function type of a PrimFunc
PrimFunc::PrimFunc(Array<tir::Var> params, PrimFunc::PrimFunc(Array<tir::Var> params,
Stmt body, Stmt body,
Type ret_type, Type ret_type,
...@@ -43,6 +45,7 @@ PrimFunc::PrimFunc(Array<tir::Var> params, ...@@ -43,6 +45,7 @@ PrimFunc::PrimFunc(Array<tir::Var> params,
n->ret_type = std::move(ret_type); n->ret_type = std::move(ret_type);
n->buffer_map = std::move(buffer_map); n->buffer_map = std::move(buffer_map);
n->attrs = std::move(attrs); n->attrs = std::move(attrs);
n->checked_type_ = n->func_type_annotation();
data_ = std::move(n); data_ = std::move(n);
} }
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tir/ir/transform.cc
* \brief TIR specific transformation passes.
*/
#include <tvm/runtime/registry.h>
#include <tvm/node/repr_printer.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tir {
namespace transform {
/*!
* \brief Function level pass that applies transformations to all
* TIR functions within the module.
*/
class PrimFuncPassNode : public PassNode {
public:
/* \brief The pass meta data.*/
PassInfo pass_info;
/*! \brief The pass function called on each. */
runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pass_info", &pass_info);
}
/*!
* \brief Run a function pass on given pass context.
*
* \param mod The module that an optimization pass is applied on.
* \param pass_ctx The context that an optimization pass executes on.
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
*/
PassInfo Info() const override { return pass_info; }
static constexpr const char* _type_key = "tir.PrimFuncPass";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncPassNode, PassNode);
};
class PrimFuncPass : public Pass {
public:
/*!
* \brief The constructor
* \param pass_func The packed function which implements a pass.
* \param pass_info The pass info.
*/
TVM_DLL PrimFuncPass(
runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
PassInfo pass_info);
TVM_DEFINE_OBJECT_REF_METHODS(PrimFuncPass, Pass, PrimFuncPassNode);
};
PrimFuncPass::PrimFuncPass(
runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<PrimFuncPassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
data_ = std::move(n);
}
// Perform Module -> Module optimizations at the PrimFunc level.
IRModule PrimFuncPassNode::operator()(const IRModule& mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(
mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
for (const auto& it : updated_mod->functions) {
// only picks up relay::PrimFunc
if (auto* n = it.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
auto updated_func =
pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
}
}
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
}
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return PrimFuncPass(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(PrimFuncPassNode);
TVM_REGISTER_GLOBAL("tir.transform.CreatePrimFuncPass")
.set_body_typed([](runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
return PrimFuncPass(pass_func, pass_info);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimFuncPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PrimFuncPassNode*>(ref.get());
const PassInfo info = node->Info();
p->stream << "PrimFuncPass(" << info->name
<< ", opt_level=" << info->opt_level << ")";
});
} // namespace transform
} // namespace tir
} // namespace tvm
...@@ -25,7 +25,11 @@ ...@@ -25,7 +25,11 @@
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h> #include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/ir_pass.h> #include <tvm/tir/ir_pass.h>
#include <map> #include <map>
namespace tvm { namespace tvm {
...@@ -114,5 +118,20 @@ LoweredFunc CombineContextCall(LoweredFunc f) { ...@@ -114,5 +118,20 @@ LoweredFunc CombineContextCall(LoweredFunc f) {
return LoweredFunc(n); return LoweredFunc(n);
} }
namespace transform {
Pass CombineContextCall() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = ContextCallCombiner().Combine(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {});
}
TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
.set_body_typed(CombineContextCall);
} // namespace transform
} // namespace tir } // namespace tir
} // namespace tvm } // namespace tvm
...@@ -37,9 +37,15 @@ def test_for(): ...@@ -37,9 +37,15 @@ def test_for():
("int32", "fadd", device_context(0), A)) ("int32", "fadd", device_context(0), A))
body = ib.get() body = ib.get()
f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True) f = tvm.tir.ir_pass.MakeAPI(body, "func", [dev_type, n], 2, True)
f = tvm.tir.ir_pass.CombineContextCall(f)
assert f.body.value.dtype == "handle" # temp adapter to convert loweredFunc to IRModule
assert f.body.body.value.dtype == "handle" # to test passes in the new style.
mod = tvm.testing.LoweredFuncsToIRModule([f])
mod = tvm.tir.transform.CombineContextCall()(mod)
assert mod["func"].body.value.dtype == "handle"
assert mod["func"].body.body.value.dtype == "handle"
if __name__ == "__main__": if __name__ == "__main__":
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
def test_prim_func_pass():
@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
"""Simple test function to replace one argument to another."""
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
return self.new_func
x = te.var('x')
y = te.var('y')
b = tvm.tir.decl_buffer((x,), "float32")
stmt = tvm.tir.LetStmt(
x, 10, tvm.tir.Evaluate(x + 1));
func = tvm.tir.PrimFunc(
[x, y, b], stmt)
new_func = tvm.tir.PrimFunc(
[x, y, b], tvm.tir.Evaluate(0))
mod = tvm.IRModule({"main": func})
mod = TestReplaceFunc(new_func)(mod)
assert tvm.tir.ir_pass.Equal(mod["main"].body, new_func.body)
if __name__ == "__main__":
test_prim_func_pass()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment