Unverified Commit e4b80bda by Tianqi Chen Committed by GitHub

[IR][TRANSFORM] Enable CopyOnWrite for passes. (#5309)

This PR enables the copy on write optimizations passes:
- Enable COW for IRModule both TIR and relay passes.
- Enabled COW for PrimFunc in TIR passes.

Need more thoughts into whether/how to enable COW
for relay::Function, due to some function passes depend
on the presence of IRModule for context information,
and the std::move of the related function to nullptr
might affect the related behavior.
parent 5b37d4c1
......@@ -124,7 +124,7 @@ class PrimExpr : public BaseExpr {
private:
// Internal function for conversion.
friend struct runtime::PackedFuncValueConverter<PrimExpr>;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
};
/*!
......@@ -464,9 +464,8 @@ struct PackedFuncValueConverter<PrimExpr> {
if (val.type_code() == kDLFloat) {
return PrimExpr(static_cast<float>(val.operator double()));
}
TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
Object* ptr = val.ptr<Object>();
return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));
return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};
} // namespace runtime
......
......@@ -62,6 +62,7 @@
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <string>
#include <utility>
namespace tvm {
namespace transform {
......@@ -251,8 +252,8 @@ class PassNode : public Object {
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
return this->operator()(mod, PassContext::Current());
IRModule operator()(IRModule mod) const {
return this->operator()(std::move(mod), PassContext::Current());
}
/*!
......@@ -263,7 +264,7 @@ class PassNode : public Object {
*
* \return The transformed module.
*/
virtual IRModule operator()(const IRModule& mod,
virtual IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const = 0;
void VisitAttrs(AttrVisitor* v) {}
......@@ -277,14 +278,22 @@ class Pass : public ObjectRef {
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \code
*
* // If you do no longer need the input module
* // it is recommended to use std::move to move your input module.
* mod = pass(std::move(mod));
*
* \endcode
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
IRModule operator()(IRModule mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
return node->operator()(std::move(mod));
}
/*!
* \brief Transform mod using a functor under a given pass context.
......@@ -294,11 +303,11 @@ class Pass : public ObjectRef {
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod,
IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
return node->operator()(std::move(mod), pass_ctx);
}
TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
......
......@@ -325,6 +325,7 @@ inline const char* TypeCode2Str(int type_code) {
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
case kTVMObjectRValueRefArg: return "ObjectRValueRefArg";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......
......@@ -885,7 +885,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
CHECK(!ref.defined() || ref->template IsInstance<typename SubRef::ContainerType>())
<< "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(std::move(ref.data_));
......
......@@ -1357,7 +1357,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
} else if (std::is_rvalue_reference<T>::value) {
} else if (std::is_rvalue_reference<decltype(value)>::value) {
values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
type_codes_[i] = kTVMObjectRValueRefArg;
} else {
......
......@@ -197,11 +197,12 @@ TVM_DLL Pass CombineContextCall();
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
* \param target_bits The target bits
*
* \note Run this pass after storage flatten.
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();
TVM_DLL Pass NarrowDataType(int target_bits);
} // namespace transform
} // namespace tir
......
......@@ -54,6 +54,7 @@ class InternalError(TVMError):
register_error("ValueError", ValueError)
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)
@register_error
......
......@@ -38,7 +38,7 @@ def Apply(ftransform):
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return ftransform(func)
return _fpass.prim_func_pass(_transform, opt_level=0)
return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply")
def Filter(fcond):
......@@ -57,7 +57,7 @@ def Filter(fcond):
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return func if fcond(func) else None
return _fpass.prim_func_pass(_transform, opt_level=0)
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
def LowerCustomDatatypes():
......@@ -221,9 +221,14 @@ def CombineContextCall():
return _ffi_api.CombineContextCall()
def NarrowDataType():
def NarrowDataType(target_bits):
"""Narrow down PrimExpr datatype in stmt to target_bits.
Parameters
----------
target_bits : int
The target bit configuration.
Returns
-------
fpass : tvm.ir.transform.Pass
......@@ -233,4 +238,4 @@ def NarrowDataType():
----
Run this pass after StorageFlatten.
"""
return _ffi_api.NarrowDataType()
return _ffi_api.NarrowDataType(target_bits)
......@@ -43,21 +43,21 @@ PrimExpr::PrimExpr(float value)
PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {}
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
if (ptr->IsInstance<tir::IterVarNode>()) {
return tir::IterVar(ptr)->var;
if (auto* ptr = ref.as<tir::IterVarNode>()) {
return GetRef<tir::IterVar>(ptr)->var;
}
if (ptr->IsInstance<te::TensorNode>()) {
return te::Tensor(ptr)();
if (auto* ptr = ref.as<te::TensorNode>()) {
return GetRef<te::Tensor>(ptr)();
}
if (ptr->IsInstance<runtime::StringObj>()) {
return tir::StringImmNode::make(runtime::String(ptr));
if (auto* ptr = ref.as<runtime::StringObj>()) {
return tir::StringImmNode::make(GetRef<runtime::String>(ptr));
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
CHECK(ObjectTypeChecker<PrimExpr>::Check(ref.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return PrimExpr(ptr);
<< " but get " << ref->GetTypeKey();
return Downcast<PrimExpr>(ref);
}
......
......@@ -121,8 +121,20 @@ bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {
GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Module";
if (it == global_var_map_.end()) {
std::ostringstream msg;
msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
<< "candidates are: [";
int counter = 0;
for (auto kv : global_var_map_) {
if (counter++ != 0) {
msg << ", ";
}
msg << "\"" << kv.first << "\"";
}
msg << "]";
LOG(FATAL) << msg.str();
}
return (*it).second;
}
......
......@@ -126,7 +126,7 @@ class ModulePassNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
......@@ -205,7 +205,7 @@ class SequentialNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
......@@ -231,19 +231,20 @@ ModulePass::ModulePass(
}
// Module -> Module optimizations.
IRModule ModulePassNode::operator()(const IRModule& mod,
IRModule ModulePassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
mod = pass_func(std::move(mod), pass_ctx);
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, false);
return mod;
}
Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
......@@ -314,18 +315,17 @@ Pass GetPass(const std::string& pass_name) {
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(const IRModule& module,
IRModule SequentialNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
IRModule mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(mod, pass_ctx);
mod = GetPass(it)(std::move(mod), pass_ctx);
}
mod = pass(mod, pass_ctx);
mod = pass(std::move(mod), pass_ctx);
}
return mod;
}
......@@ -375,11 +375,8 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass")
});
TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
IRModule mod = args[1];
ObjectRef ref = args[1];
*ret = pass(mod);
.set_body_typed([](Pass pass, IRModule mod) {
return pass(std::move(mod));
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
......@@ -24,7 +24,7 @@
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <cstring>
#include "../support/str_escape.h"
namespace tvm {
......@@ -63,6 +63,13 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
static_cast<const runtime::StringObj*>(n)).operator std::string();
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const runtime::StringObj*>(node.get());
p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
});
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
......
......@@ -68,7 +68,7 @@ class FunctionPassNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
......@@ -113,7 +113,7 @@ FunctionPass::FunctionPass(
}
// Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(const IRModule& mod,
IRModule FunctionPassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
......@@ -122,6 +122,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< " with opt level: "
<< pass_info->opt_level;
pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file support/str_escape.h
* \brief Print escape sequence of a string.
*/
#ifndef TVM_SUPPORT_STR_ESCAPE_H_
#define TVM_SUPPORT_STR_ESCAPE_H_
#include <string>
#include <sstream>
namespace tvm {
namespace support {
/*!
* \brief Create a stream with escape.
* \param data The data
* \param size The size of the string.
* \return the Result string.
*/
inline std::string StrEscape(const char* data, size_t size) {
std::ostringstream stream;
for (size_t i = 0; i < size; ++i) {
unsigned char c = data[i];
if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
stream << c;
} else {
stream << '\\';
switch (c) {
case '"':
stream << '"';
break;
case '\\':
stream << '\\';
break;
case '\t':
stream << 't';
break;
case '\r':
stream << 'r';
break;
case '\n':
stream << 'n';
break;
default:
const char* hex_digits = "0123456789ABCDEF";
stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
}
}
}
return stream.str();
}
/*!
* \brief Create a stream with escape.
* \param data The data
* \param size The size of the string.
* \return the Result string.
*/
inline std::string StrEscape(const std::string& val) {
return StrEscape(val.data(), val.length());
}
} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_STR_ESCAPE_H_
......@@ -28,6 +28,7 @@
#include <memory>
#include <limits>
#include "../pass/ir_util.h"
#include "../../support/str_escape.h"
namespace tvm {
namespace tir {
......@@ -425,38 +426,8 @@ TVM_REGISTER_NODE_TYPE(BufferLoadNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StringImmNode*>(node.get());
auto& stream = p->stream;
stream << '"';
for (size_t i = 0; i < op->value.size(); ++i) {
unsigned char c = op->value[i];
if (c >= ' ' && c <= '~' && c != '\\' && c != '"') {
stream << c;
} else {
stream << '\\';
switch (c) {
case '"':
stream << '"';
break;
case '\\':
stream << '\\';
break;
case '\t':
stream << 't';
break;
case '\r':
stream << 'r';
break;
case '\n':
stream << 'n';
break;
default:
const char* hex_digits = "0123456789ABCDEF";
stream << 'x' << hex_digits[c >> 4] << hex_digits[c & 0xf];
}
}
}
stream << '"';
});
p->stream << '\"' << support::StrEscape(op->value) << '\"';
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CastNode>([](const ObjectRef& node, ReprPrinter* p) {
......
......@@ -55,7 +55,7 @@ class PrimFuncPassNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
/*!
* \brief Get the pass information/meta data.
......@@ -90,34 +90,35 @@ PrimFuncPass::PrimFuncPass(
}
// Perform Module -> Module optimizations at the PrimFunc level.
IRModule PrimFuncPassNode::operator()(const IRModule& mod,
IRModule PrimFuncPassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(
mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
for (const auto& it : updated_mod->functions) {
// only picks up relay::PrimFunc
if (auto* n = it.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
auto updated_func =
pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
std::vector<ObjectRef> deleted_list;
IRModuleNode* mod_ptr = mod.CopyOnWrite();
auto* func_dict = mod_ptr->functions.CopyOnWrite();
// directly loop over the underlying dict
for (auto& kv : func_dict->data) {
// only picks up tir::PrimFunc
if (kv.second->IsInstance<PrimFuncNode>()) {
// move out the function so that it is the only copy.
PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
func = pass_func(std::move(func), mod, pass_ctx);
kv.second = std::move(func);
if (!kv.second.defined()) {
deleted_list.push_back(kv.first);
}
}
}
// automatic removal of None
for (const auto& pair : updates) {
if (pair.second.defined()) {
updated_mod->Add(pair.first, pair.second, true);
} else {
updated_mod->Remove(pair.first);
}
for (const auto& gv : deleted_list) {
func_dict->data.erase(gv);
}
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
pass_ctx.Trace(mod, pass_info, false);
return mod;
}
Pass CreatePrimFuncPass(
......
......@@ -397,17 +397,14 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) {
namespace transform {
Pass NarrowDataType() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
Pass NarrowDataType(int target_bits) {
auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
IntImm target_bits = f->GetAttr<IntImm>("target_bits");
CHECK(target_bits.defined())
<< "NarrowDataType: Require the target_bits";
n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
n->body = DataTypeRewriter(target_bits)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(
pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
pass_func, 0, "tir.NarrowDataType", {});
}
TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
......
......@@ -173,7 +173,7 @@ Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
class HostDeviceSplitter : public StmtMutator {
public:
explicit HostDeviceSplitter(IRModuleNode* device_mod,
explicit HostDeviceSplitter(IRModule* device_mod,
Target device_target,
std::string name_prefix)
: device_mod_(device_mod),
......@@ -240,7 +240,7 @@ class HostDeviceSplitter : public StmtMutator {
runtime::String(kernel_symbol));
device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias, Integer(1));
device_func = WithAttr(std::move(device_func), tvm::attr::kTarget, device_target_);
device_mod_->Add(GlobalVar(kernel_symbol), device_func);
(*device_mod_)->Add(GlobalVar(kernel_symbol), device_func);
// generate calls to the device function
Array<PrimExpr> call_args;
......@@ -257,7 +257,7 @@ class HostDeviceSplitter : public StmtMutator {
}
// target ir module
IRModuleNode* device_mod_;
IRModule* device_mod_;
// Device target
Target device_target_;
// function name hint
......@@ -268,7 +268,7 @@ class HostDeviceSplitter : public StmtMutator {
};
PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "SplitHostDevice: Require the target attribute";
......@@ -287,26 +287,22 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
}
namespace transform {
Pass SplitHostDevice() {
auto pass_func = [](IRModule m, PassContext ctx) {
IRModuleNode* mptr = m.CopyOnWrite();
std::vector<std::pair<GlobalVar, PrimFunc> > updates;
for (const auto& kv : mptr->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
PrimFunc func = GetRef<PrimFunc>(n);
auto updated_func = SplitHostDevice(std::move(func), mptr);
updates.push_back({kv.first, updated_func});
auto pass_func = [](IRModule mod, PassContext ctx) {
IRModuleNode* mod_ptr = mod.CopyOnWrite();
auto* func_dict = mod_ptr->functions.CopyOnWrite();
IRModule device_mod = IRModule::Empty();
for (auto& kv : func_dict->data) {
if (kv.second->IsInstance<PrimFuncNode>()) {
PrimFunc func = Downcast<PrimFunc>(std::move(kv.second));
kv.second = SplitHostDevice(std::move(func), &device_mod);
}
}
for (const auto& pair : updates) {
mptr->Add(pair.first, pair.second, true);
}
return m;
mod->Update(device_mod);
return mod;
};
return tvm::transform::CreateModulePass(
......
......@@ -22,6 +22,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/expr.h>
TEST(PackedFunc, Basic) {
......@@ -274,6 +275,16 @@ TEST(TypedPackedFunc, RValue) {
using namespace tvm;
using namespace tvm::runtime;
{
auto inspect = [](TVMArgs args, TVMRetValue* rv) {
for (int i = 0; i < args.size(); ++i) {
CHECK_EQ(args[0].type_code(), kTVMObjectRValueRefArg);
}
};
PackedFunc finspect(inspect);
finspect(tir::Var("x"));
}
{
auto f = [](tir::Var x, bool move) {
if (move) {
CHECK(x.unique());
......@@ -287,9 +298,9 @@ TEST(TypedPackedFunc, RValue) {
tir::Var var("x");
CHECK(var.unique());
f(var, false);
tf(var, false);
// move the result to the function.
tir::Var ret = f(std::move(var), true);
tir::Var ret = tf(std::move(var), true);
CHECK(!var.defined());
}
......@@ -307,10 +318,10 @@ TEST(TypedPackedFunc, RValue) {
tir::Var var("x");
CHECK(var.unique());
f(var, false);
f(std::move(var), true);
tf(var, false);
tf(std::move(var), true);
// auto conversion.
f(1, true);
tf(1, true);
}
}
......
......@@ -20,9 +20,9 @@ from tvm.tir import const
def lower_stmt(params, stmt, target_bits):
func = tvm.tir.PrimFunc(params, stmt).with_attr(
"target_bits", target_bits)
func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
func = tvm.tir.PrimFunc(params, stmt)
func = tvm.tir.transform.NarrowDataType(target_bits)(
tvm.IRModule.from_expr(func))["main"]
stmt = func.body
return stmt
......
......@@ -46,5 +46,25 @@ def test_prim_func_pass():
assert tvm.ir.structural_equal(mod["main"].body, new_func.body)
def test_cow_pass():
def fapply(f):
assert tvm.testing.object_use_count(f) == 1
return f
pidentity = tvm.tir.transform.Apply(fapply)
x = te.var('x')
func = tvm.tir.PrimFunc(
[x], tvm.tir.Evaluate(x)).with_attr("target_bits", 32)
func_hash = func.__hash__()
mod = tvm.IRModule({"main": func})
del func
# copy on write
mod_hash = mod.__hash__()
mod = tvm.ir.transform.Sequential(
[pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move())
assert mod_hash == mod.__hash__()
assert func_hash == mod["main"].__hash__()
if __name__ == "__main__":
test_cow_pass()
test_prim_func_pass()
......@@ -42,12 +42,5 @@ inline Array<Integer> ArrayOrInt(TVMArgValue arg) {
return arg;
}
}
inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kTVMObjectHandle &&
static_cast<Object*>(
arg.value().v_handle)->IsInstance<tvm::te::TensorNode>());
}
} // namespace topi
#endif // TOPI_UTIL_H_
......@@ -35,8 +35,8 @@ using namespace tvm::runtime;
#define TOPI_REGISTER_BCAST_OP(OpName, Op) \
TVM_REGISTER_GLOBAL(OpName) \
.set_body([](TVMArgs args, TVMRetValue *rv) { \
bool lhs_is_tensor = IsTensorType(args[0]); \
bool rhs_is_tensor = IsTensorType(args[1]); \
bool lhs_is_tensor = args[0].IsObjectRef<tvm::te::Tensor>(); \
bool rhs_is_tensor = args[1].IsObjectRef<tvm::te::Tensor>(); \
if (lhs_is_tensor && rhs_is_tensor) { \
*rv = Op(args[0].operator tvm::te::Tensor(), \
args[1].operator tvm::te::Tensor()); \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment