Unverified Commit e4b80bda by Tianqi Chen Committed by GitHub

[IR][TRANSFORM] Enable CopyOnWrite for passes. (#5309)

This PR enables the copy on write optimizations passes:
- Enable COW for IRModule both TIR and relay passes.
- Enabled COW for PrimFunc in TIR passes.

Need more thoughts into whether/how to enable COW
for relay::Function, due to some function passes depend
on the presence of IRModule for context information,
and the std::move of the related function to nullptr
might affect the related behavior.
parent 5b37d4c1
...@@ -124,7 +124,7 @@ class PrimExpr : public BaseExpr { ...@@ -124,7 +124,7 @@ class PrimExpr : public BaseExpr {
private: private:
// Internal function for conversion. // Internal function for conversion.
friend struct runtime::PackedFuncValueConverter<PrimExpr>; friend struct runtime::PackedFuncValueConverter<PrimExpr>;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr); TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
}; };
/*! /*!
...@@ -464,9 +464,8 @@ struct PackedFuncValueConverter<PrimExpr> { ...@@ -464,9 +464,8 @@ struct PackedFuncValueConverter<PrimExpr> {
if (val.type_code() == kDLFloat) { if (val.type_code() == kDLFloat) {
return PrimExpr(static_cast<float>(val.operator double())); return PrimExpr(static_cast<float>(val.operator double()));
} }
TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
Object* ptr = val.ptr<Object>(); return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));
} }
}; };
} // namespace runtime } // namespace runtime
......
...@@ -62,6 +62,7 @@ ...@@ -62,6 +62,7 @@
#include <tvm/ir/error.h> #include <tvm/ir/error.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <string> #include <string>
#include <utility>
namespace tvm { namespace tvm {
namespace transform { namespace transform {
...@@ -251,8 +252,8 @@ class PassNode : public Object { ...@@ -251,8 +252,8 @@ class PassNode : public Object {
* *
* \return The transformed module. * \return The transformed module.
*/ */
IRModule operator()(const IRModule& mod) const { IRModule operator()(IRModule mod) const {
return this->operator()(mod, PassContext::Current()); return this->operator()(std::move(mod), PassContext::Current());
} }
/*! /*!
...@@ -263,7 +264,7 @@ class PassNode : public Object { ...@@ -263,7 +264,7 @@ class PassNode : public Object {
* *
* \return The transformed module. * \return The transformed module.
*/ */
virtual IRModule operator()(const IRModule& mod, virtual IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const = 0; const PassContext& pass_ctx) const = 0;
void VisitAttrs(AttrVisitor* v) {} void VisitAttrs(AttrVisitor* v) {}
...@@ -277,14 +278,22 @@ class Pass : public ObjectRef { ...@@ -277,14 +278,22 @@ class Pass : public ObjectRef {
/*! /*!
* \brief Transform mod using the default PassContext in the current scope. * \brief Transform mod using the default PassContext in the current scope.
* *
* \code
*
* // If you do no longer need the input module
* // it is recommended to use std::move to move your input module.
* mod = pass(std::move(mod));
*
* \endcode
*
* \param mod The module that an optimization pass runs on. * \param mod The module that an optimization pass runs on.
* *
* \return The transformed module. * \return The transformed module.
*/ */
IRModule operator()(const IRModule& mod) const { IRModule operator()(IRModule mod) const {
const PassNode* node = operator->(); const PassNode* node = operator->();
CHECK(node != nullptr); CHECK(node != nullptr);
return node->operator()(mod); return node->operator()(std::move(mod));
} }
/*! /*!
* \brief Transform mod using a functor under a given pass context. * \brief Transform mod using a functor under a given pass context.
...@@ -294,11 +303,11 @@ class Pass : public ObjectRef { ...@@ -294,11 +303,11 @@ class Pass : public ObjectRef {
* *
* \return The transformed module. * \return The transformed module.
*/ */
IRModule operator()(const IRModule& mod, IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassNode* node = operator->(); const PassNode* node = operator->();
CHECK(node != nullptr); CHECK(node != nullptr);
return node->operator()(mod, pass_ctx); return node->operator()(std::move(mod), pass_ctx);
} }
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
......
...@@ -325,6 +325,7 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -325,6 +325,7 @@ inline const char* TypeCode2Str(int type_code) {
case kTVMModuleHandle: return "ModuleHandle"; case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer"; case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object"; case kTVMObjectHandle: return "Object";
case kTVMObjectRValueRefArg: return "ObjectRValueRefArg";
default: LOG(FATAL) << "unknown type_code=" default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return ""; << static_cast<int>(type_code); return "";
} }
......
...@@ -885,7 +885,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) { ...@@ -885,7 +885,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
template <typename SubRef, typename BaseRef> template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) { inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>()) CHECK(!ref.defined() || ref->template IsInstance<typename SubRef::ContainerType>())
<< "Downcast from " << ref->GetTypeKey() << " to " << "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed."; << SubRef::ContainerType::_type_key << " failed.";
return SubRef(std::move(ref.data_)); return SubRef(std::move(ref.data_));
......
...@@ -1357,7 +1357,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { ...@@ -1357,7 +1357,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
ptr->IsInstance<Module::ContainerType>())) { ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr; values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle; type_codes_[i] = kTVMModuleHandle;
} else if (std::is_rvalue_reference<T>::value) { } else if (std::is_rvalue_reference<decltype(value)>::value) {
values_[i].v_handle = const_cast<Object**>(&(value.data_.data_)); values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
type_codes_[i] = kTVMObjectRValueRefArg; type_codes_[i] = kTVMObjectRValueRefArg;
} else { } else {
......
...@@ -197,11 +197,12 @@ TVM_DLL Pass CombineContextCall(); ...@@ -197,11 +197,12 @@ TVM_DLL Pass CombineContextCall();
/*! /*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits. * \brief Narrow down PrimExpr datatype in stmt to target_bits.
* *
* \note Run this pass after StorageFlatten. * \param target_bits The target bits
* *
* \note Run this pass after storage flatten.
* \return The pass. * \return The pass.
*/ */
TVM_DLL Pass NarrowDataType(); TVM_DLL Pass NarrowDataType(int target_bits);
} // namespace transform } // namespace transform
} // namespace tir } // namespace tir
......
...@@ -54,6 +54,7 @@ class InternalError(TVMError): ...@@ -54,6 +54,7 @@ class InternalError(TVMError):
register_error("ValueError", ValueError) register_error("ValueError", ValueError)
register_error("TypeError", TypeError) register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError) register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)
@register_error @register_error
......
...@@ -38,7 +38,7 @@ def Apply(ftransform): ...@@ -38,7 +38,7 @@ def Apply(ftransform):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def _transform(func, mod, ctx): def _transform(func, mod, ctx):
return ftransform(func) return ftransform(func)
return _fpass.prim_func_pass(_transform, opt_level=0) return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply")
def Filter(fcond): def Filter(fcond):
...@@ -57,7 +57,7 @@ def Filter(fcond): ...@@ -57,7 +57,7 @@ def Filter(fcond):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def _transform(func, mod, ctx): def _transform(func, mod, ctx):
return func if fcond(func) else None return func if fcond(func) else None
return _fpass.prim_func_pass(_transform, opt_level=0) return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
def LowerCustomDatatypes(): def LowerCustomDatatypes():
...@@ -221,9 +221,14 @@ def CombineContextCall(): ...@@ -221,9 +221,14 @@ def CombineContextCall():
return _ffi_api.CombineContextCall() return _ffi_api.CombineContextCall()
def NarrowDataType(): def NarrowDataType(target_bits):
"""Narrow down PrimExpr datatype in stmt to target_bits. """Narrow down PrimExpr datatype in stmt to target_bits.
Parameters
----------
target_bits : int
The target bit configuration.
Returns Returns
------- -------
fpass : tvm.ir.transform.Pass fpass : tvm.ir.transform.Pass
...@@ -233,4 +238,4 @@ def NarrowDataType(): ...@@ -233,4 +238,4 @@ def NarrowDataType():
---- ----
Run this pass after StorageFlatten. Run this pass after StorageFlatten.
""" """
return _ffi_api.NarrowDataType() return _ffi_api.NarrowDataType(target_bits)
...@@ -43,21 +43,21 @@ PrimExpr::PrimExpr(float value) ...@@ -43,21 +43,21 @@ PrimExpr::PrimExpr(float value)
PrimExpr::PrimExpr(runtime::String value) PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {} : PrimExpr(tir::StringImmNode::make(value)) {}
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) { PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker; using runtime::ObjectTypeChecker;
if (ptr->IsInstance<tir::IterVarNode>()) { if (auto* ptr = ref.as<tir::IterVarNode>()) {
return tir::IterVar(ptr)->var; return GetRef<tir::IterVar>(ptr)->var;
} }
if (ptr->IsInstance<te::TensorNode>()) { if (auto* ptr = ref.as<te::TensorNode>()) {
return te::Tensor(ptr)(); return GetRef<te::Tensor>(ptr)();
} }
if (ptr->IsInstance<runtime::StringObj>()) { if (auto* ptr = ref.as<runtime::StringObj>()) {
return tir::StringImmNode::make(runtime::String(ptr)); return tir::StringImmNode::make(GetRef<runtime::String>(ptr));
} }
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get())) CHECK(ObjectTypeChecker<PrimExpr>::Check(ref.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName() << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey(); << " but get " << ref->GetTypeKey();
return PrimExpr(ptr); return Downcast<PrimExpr>(ref);
} }
......
...@@ -121,8 +121,20 @@ bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const { ...@@ -121,8 +121,20 @@ bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const { GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name); auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end()) if (it == global_var_map_.end()) {
<< "Cannot find global var " << name << " in the Module"; std::ostringstream msg;
msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
<< "candidates are: [";
int counter = 0;
for (auto kv : global_var_map_) {
if (counter++ != 0) {
msg << ", ";
}
msg << "\"" << kv.first << "\"";
}
msg << "]";
LOG(FATAL) << msg.str();
}
return (*it).second; return (*it).second;
} }
......
...@@ -126,7 +126,7 @@ class ModulePassNode : public PassNode { ...@@ -126,7 +126,7 @@ class ModulePassNode : public PassNode {
* *
* \return Return the updated module. * \return Return the updated module.
*/ */
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
...@@ -205,7 +205,7 @@ class SequentialNode : public PassNode { ...@@ -205,7 +205,7 @@ class SequentialNode : public PassNode {
* *
* \return Return the updated module. * \return Return the updated module.
*/ */
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "transform.Sequential"; static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
...@@ -231,19 +231,20 @@ ModulePass::ModulePass( ...@@ -231,19 +231,20 @@ ModulePass::ModulePass(
} }
// Module -> Module optimizations. // Module -> Module optimizations.
IRModule ModulePassNode::operator()(const IRModule& mod, IRModule ModulePassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info(); const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : " DLOG(INFO) << "Executing module pass : "
<< pass_info->name << pass_info->name
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
CHECK(mod.defined()); CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true); pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx); mod = pass_func(std::move(mod), pass_ctx);
CHECK(updated_mod.defined()); CHECK(mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false); pass_ctx.Trace(mod, pass_info, false);
return updated_mod; return mod;
} }
Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) { Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
...@@ -314,18 +315,17 @@ Pass GetPass(const std::string& pass_name) { ...@@ -314,18 +315,17 @@ Pass GetPass(const std::string& pass_name) {
// TODO(zhiics): we currenlty only sequentially execute each pass in // TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase // a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future. // ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(const IRModule& module, IRModule SequentialNode::operator()(IRModule mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
IRModule mod = module;
for (const Pass& pass : passes) { for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization."; CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info(); const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue; if (!PassEnabled(pass_info)) continue;
// resolve dependencies // resolve dependencies
for (const auto& it : pass_info->required) { for (const auto& it : pass_info->required) {
mod = GetPass(it)(mod, pass_ctx); mod = GetPass(it)(std::move(mod), pass_ctx);
} }
mod = pass(mod, pass_ctx); mod = pass(std::move(mod), pass_ctx);
} }
return mod; return mod;
} }
...@@ -375,11 +375,8 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass") ...@@ -375,11 +375,8 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass")
}); });
TVM_REGISTER_GLOBAL("transform.RunPass") TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed([](Pass pass, IRModule mod) {
Pass pass = args[0]; return pass(std::move(mod));
IRModule mod = args[1];
ObjectRef ref = args[1];
*ret = pass(mod);
}); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <cstring> #include "../support/str_escape.h"
namespace tvm { namespace tvm {
...@@ -63,6 +63,13 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait) ...@@ -63,6 +63,13 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
static_cast<const runtime::StringObj*>(n)).operator std::string(); static_cast<const runtime::StringObj*>(n)).operator std::string();
}); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const runtime::StringObj*>(node.get());
p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
});
struct ADTObjTrait { struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr; static constexpr const std::nullptr_t VisitAttrs = nullptr;
......
...@@ -68,7 +68,7 @@ class FunctionPassNode : public PassNode { ...@@ -68,7 +68,7 @@ class FunctionPassNode : public PassNode {
* *
* \return Return the updated module. * \return Return the updated module.
*/ */
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
...@@ -113,7 +113,7 @@ FunctionPass::FunctionPass( ...@@ -113,7 +113,7 @@ FunctionPass::FunctionPass(
} }
// Perform Module -> Module optimizations at the Function level. // Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(const IRModule& mod, IRModule FunctionPassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info(); const PassInfo& pass_info = Info();
CHECK(mod.defined()); CHECK(mod.defined());
...@@ -122,6 +122,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, ...@@ -122,6 +122,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
pass_ctx.Trace(mod, pass_info, true); pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file support/str_escape.h
* \brief Print escape sequence of a string.
*/
#ifndef TVM_SUPPORT_STR_ESCAPE_H_
#define TVM_SUPPORT_STR_ESCAPE_H_
#include <string>
#include <sstream>
namespace tvm {
namespace support {
/*!
* \brief Create a stream with escape.
* \param data The data
* \param size The size of the string.
* \return the Result string.
*/
inline std::string StrEscape(const char* data, size_t size) {
std::ostringstream stream;
for (size_t i = 0; i < size; ++i) {
unsigned char c = data[i];
if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
stream << c;
} else {
stream << '\\';
switch (c) {
case '"':
stream << '"';
break;
case '\\':
stream << '\\';
break;
case '\t':
stream << 't';
break;
case '\r':
stream << 'r';
break;
case '\n':
stream << 'n';
break;
default:
const char* hex_digits = "0123456789ABCDEF";
stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
}
}
}
return stream.str();
}
/*!
* \brief Create a stream with escape.
* \param data The data
* \param size The size of the string.
* \return the Result string.
*/
inline std::string StrEscape(const std::string& val) {
return StrEscape(val.data(), val.length());
}
} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_STR_ESCAPE_H_
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <memory> #include <memory>
#include <limits> #include <limits>
#include "../pass/ir_util.h" #include "../pass/ir_util.h"
#include "../../support/str_escape.h"
namespace tvm { namespace tvm {
namespace tir { namespace tir {
...@@ -425,38 +426,8 @@ TVM_REGISTER_NODE_TYPE(BufferLoadNode); ...@@ -425,38 +426,8 @@ TVM_REGISTER_NODE_TYPE(BufferLoadNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StringImmNode*>(node.get()); auto* op = static_cast<const StringImmNode*>(node.get());
auto& stream = p->stream; p->stream << '\"' << support::StrEscape(op->value) << '\"';
stream << '"'; });
for (size_t i = 0; i < op->value.size(); ++i) {
unsigned char c = op->value[i];
if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
stream << c;
} else {
stream << '\\';
switch (c) {
case '"':
stream << '"';
break;
case '\\':
stream << '\\';
break;
case '\t':
stream << 't';
break;
case '\r':
stream << 'r';
break;
case '\n':
stream << 'n';
break;
default:
const char* hex_digits = "0123456789ABCDEF";
stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
}
}
}
stream << '"';
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
......
...@@ -55,7 +55,7 @@ class PrimFuncPassNode : public PassNode { ...@@ -55,7 +55,7 @@ class PrimFuncPassNode : public PassNode {
* *
* \return Return the updated module. * \return Return the updated module.
*/ */
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*! /*!
* \brief Get the pass information/meta data. * \brief Get the pass information/meta data.
...@@ -90,34 +90,35 @@ PrimFuncPass::PrimFuncPass( ...@@ -90,34 +90,35 @@ PrimFuncPass::PrimFuncPass(
} }
// Perform Module -> Module optimizations at the PrimFunc level. // Perform Module -> Module optimizations at the PrimFunc level.
IRModule PrimFuncPassNode::operator()(const IRModule& mod, IRModule PrimFuncPassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const { const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info(); const PassInfo& pass_info = Info();
CHECK(mod.defined()); CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true); pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module. std::vector<ObjectRef> deleted_list;
IRModule updated_mod = IRModule( IRModuleNode* mod_ptr = mod.CopyOnWrite();
mod->functions, mod->type_definitions, mod->Imports()); auto* func_dict = mod_ptr->functions.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates; // directly loop over the underlying dict
for (const auto& it : updated_mod->functions) { for (auto& kv : func_dict->data) {
// only picks up relay::PrimFunc // only picks up tir::PrimFunc
if (auto* n = it.second.as<PrimFuncNode>()) { if (kv.second->IsInstance<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n); // move out the function so that it is the only copy.
auto updated_func = PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
pass_func(func, updated_mod, pass_ctx); func = pass_func(std::move(func), mod, pass_ctx);
updates.push_back({it.first, updated_func}); kv.second = std::move(func);
if (!kv.second.defined()) {
deleted_list.push_back(kv.first);
} }
} }
// automatic removal of None
for (const auto& pair : updates) {
if (pair.second.defined()) {
updated_mod->Add(pair.first, pair.second, true);
} else {
updated_mod->Remove(pair.first);
} }
// automatic removal of None
for (const auto& gv : deleted_list) {
func_dict->data.erase(gv);
} }
pass_ctx.Trace(updated_mod, pass_info, false); pass_ctx.Trace(mod, pass_info, false);
return updated_mod; return mod;
} }
Pass CreatePrimFuncPass( Pass CreatePrimFuncPass(
......
...@@ -397,17 +397,14 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) { ...@@ -397,17 +397,14 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) {
namespace transform { namespace transform {
Pass NarrowDataType() { Pass NarrowDataType(int target_bits) {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite(); auto* n = f.CopyOnWrite();
IntImm target_bits = f->GetAttr<IntImm>("target_bits"); n->body = DataTypeRewriter(target_bits)(std::move(n->body));
CHECK(target_bits.defined())
<< "NarrowDataType: Require the target_bits";
n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
return f; return f;
}; };
return CreatePrimFuncPass( return CreatePrimFuncPass(
pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {}); pass_func, 0, "tir.NarrowDataType", {});
} }
TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType") TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
......
...@@ -173,7 +173,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) { ...@@ -173,7 +173,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
class HostDeviceSplitter : public StmtMutator { class HostDeviceSplitter : public StmtMutator {
public: public:
explicit HostDeviceSplitter(IRModuleNode* device_mod, explicit HostDeviceSplitter(IRModule* device_mod,
Target device_target, Target device_target,
std::string name_prefix) std::string name_prefix)
: device_mod_(device_mod), : device_mod_(device_mod),
...@@ -240,7 +240,7 @@ class HostDeviceSplitter : public StmtMutator { ...@@ -240,7 +240,7 @@ class HostDeviceSplitter : public StmtMutator {
runtime::String(kernel_symbol)); runtime::String(kernel_symbol));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1)); device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_); device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
device_mod_->Add(GlobalVar(kernel_symbol), device_func); (*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
// generate calls to the device function // generate calls to the device function
Array<PrimExpr> call_args; Array<PrimExpr> call_args;
...@@ -257,7 +257,7 @@ class HostDeviceSplitter : public StmtMutator { ...@@ -257,7 +257,7 @@ class HostDeviceSplitter : public StmtMutator {
} }
// target ir module // target ir module
IRModuleNode* device_mod_; IRModule* device_mod_;
// Device target // Device target
Target device_target_; Target device_target_;
// function name hint // function name hint
...@@ -268,7 +268,7 @@ class HostDeviceSplitter : public StmtMutator { ...@@ -268,7 +268,7 @@ class HostDeviceSplitter : public StmtMutator {
}; };
PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) { PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget); auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) CHECK(target.defined())
<< "SplitHostDevice: Require the target attribute"; << "SplitHostDevice: Require the target attribute";
...@@ -287,26 +287,22 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) { ...@@ -287,26 +287,22 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
} }
namespace transform { namespace transform {
Pass SplitHostDevice() { Pass SplitHostDevice() {
auto pass_func = [](IRModule m, PassContext ctx) { auto pass_func = [](IRModule mod, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite(); IRModuleNode* mod_ptr = mod.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates; auto* func_dict = mod_ptr->functions.CopyOnWrite();
IRModule device_mod = IRModule::Empty();
for (const auto& kv : mptr->functions) { for (auto& kv : func_dict->data) {
if (auto* n = kv.second.as<PrimFuncNode>()) { if (kv.second->IsInstance<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n); PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
auto updated_func = SplitHostDevice(std::move(func), mptr); kv.second = SplitHostDevice(std::move(func), &device_mod);
updates.push_back({kv.first, updated_func});
}
} }
for (const auto& pair : updates) {
mptr->Add(pair.first, pair.second, true);
} }
return m; mod->Update(device_mod);
return mod;
}; };
return tvm::transform::CreateModulePass( return tvm::transform::CreateModulePass(
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/container.h> #include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
TEST(PackedFunc, Basic) { TEST(PackedFunc, Basic) {
...@@ -274,6 +275,16 @@ TEST(TypedPackedFunc, RValue) { ...@@ -274,6 +275,16 @@ TEST(TypedPackedFunc, RValue) {
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
{ {
auto inspect = [](TVMArgs args, TVMRetValue* rv) {
for (int i = 0; i < args.size(); ++i) {
CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg);
}
};
PackedFunc finspect(inspect);
finspect(tir::Var("x"));
}
{
auto f = [](tir::Var x, bool move) { auto f = [](tir::Var x, bool move) {
if (move) { if (move) {
CHECK(x.unique()); CHECK(x.unique());
...@@ -287,9 +298,9 @@ TEST(TypedPackedFunc, RValue) { ...@@ -287,9 +298,9 @@ TEST(TypedPackedFunc, RValue) {
tir::Var var("x"); tir::Var var("x");
CHECK(var.unique()); CHECK(var.unique());
f(var, false); tf(var, false);
// move the result to the function. // move the result to the function.
tir::Var ret = f(std::move(var), true); tir::Var ret = tf(std::move(var), true);
CHECK(!var.defined()); CHECK(!var.defined());
} }
...@@ -307,10 +318,10 @@ TEST(TypedPackedFunc, RValue) { ...@@ -307,10 +318,10 @@ TEST(TypedPackedFunc, RValue) {
tir::Var var("x"); tir::Var var("x");
CHECK(var.unique()); CHECK(var.unique());
f(var, false); tf(var, false);
f(std::move(var), true); tf(std::move(var), true);
// auto conversion. // auto conversion.
f(1, true); tf(1, true);
} }
} }
......
...@@ -20,9 +20,9 @@ from tvm.tir import const ...@@ -20,9 +20,9 @@ from tvm.tir import const
def lower_stmt(params, stmt, target_bits): def lower_stmt(params, stmt, target_bits):
func = tvm.tir.PrimFunc(params, stmt).with_attr( func = tvm.tir.PrimFunc(params, stmt)
"target_bits", target_bits) func = tvm.tir.transform.NarrowDataType(target_bits)(
func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"] tvm.IRModule.from_expr(func))["main"]
stmt = func.body stmt = func.body
return stmt return stmt
......
...@@ -46,5 +46,25 @@ def test_prim_func_pass(): ...@@ -46,5 +46,25 @@ def test_prim_func_pass():
assert tvm.ir.structural_equal(mod["main"].body, new_func.body) assert tvm.ir.structural_equal(mod["main"].body, new_func.body)
def test_cow_pass():
def fapply(f):
assert tvm.testing.object_use_count(f) == 1
return f
pidentity = tvm.tir.transform.Apply(fapply)
x = te.var('x')
func = tvm.tir.PrimFunc(
[x], tvm.tir.Evaluate(x)).with_attr("target_bits", 32)
func_hash = func.__hash__()
mod = tvm.IRModule({"main": func})
del func
# copy on write
mod_hash = mod.__hash__()
mod = tvm.ir.transform.Sequential(
[pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move())
assert mod_hash == mod.__hash__()
assert func_hash == mod["main"].__hash__()
if __name__ == "__main__": if __name__ == "__main__":
test_cow_pass()
test_prim_func_pass() test_prim_func_pass()
...@@ -42,12 +42,5 @@ inline Array<Integer> ArrayOrInt(TVMArgValue arg) { ...@@ -42,12 +42,5 @@ inline Array<Integer> ArrayOrInt(TVMArgValue arg) {
return arg; return arg;
} }
} }
inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kTVMObjectHandle &&
static_cast<Object*>(
arg.value().v_handle)->IsInstance<tvm::te::TensorNode>());
}
} // namespace topi } // namespace topi
#endif // TOPI_UTIL_H_ #endif // TOPI_UTIL_H_
...@@ -35,8 +35,8 @@ using namespace tvm::runtime; ...@@ -35,8 +35,8 @@ using namespace tvm::runtime;
#define TOPI_REGISTER_BCAST_OP(OpName, Op) \ #define TOPI_REGISTER_BCAST_OP(OpName, Op) \
TVM_REGISTER_GLOBAL(OpName) \ TVM_REGISTER_GLOBAL(OpName) \
.set_body([](TVMArgs args, TVMRetValue *rv) { \ .set_body([](TVMArgs args, TVMRetValue *rv) { \
bool lhs_is_tensor = IsTensorType(args[0]); \ bool lhs_is_tensor = args[0].IsObjectRef<tvm::te::Tensor>(); \
bool rhs_is_tensor = IsTensorType(args[1]); \ bool rhs_is_tensor = args[1].IsObjectRef<tvm::te::Tensor>(); \
if (lhs_is_tensor && rhs_is_tensor) { \ if (lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \ *rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::te::Tensor()); \ args[1].operator tvm::te::Tensor()); \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment