Unverified commit 09eb5082 by Trevor Morris, committed by GitHub

[BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)

* Fix duplicate output in partitiongraph

* Add test case

* Fix test_annotated_regions with duplicate compiler_end outputs

* Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.

* Prevent duplicate outputs in Tuple in PartitionGraph

* Fix lint

* Add another test case for when regions are merged, and when TupleGetItem was duplicated

* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput

* Use std::move for GetFunctionOutput. Fix typo with testcase name

* Use tvm.transform.Sequential
parent e8138f7d
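The problem this PR fixes can be reproduced with a small Relay graph in which one annotated output is consumed several times outside the partitioned region. Below is a minimal sketch, adapted from the test case added in this PR; the target name `example_target` and the `abs` registration are illustrative, not part of the change itself:

```python
import tvm
import tvm.relay.op as reg
from tvm import relay
from tvm.relay import transform

# Illustrative external codegen target; any op registered under
# "target.<name>" is offloaded to that target by AnnotateTarget.
target = "example_target"

@reg.register("abs", "target." + target)
def _abs_supported(attrs, args):  # pylint: disable=unused-variable
    return True

# One annotated op (abs) feeds three consumers that stay on the host, so the
# external region has a single logical output that is used three times.
data = relay.var("data", shape=(10, 10))
x = relay.abs(data)
out = relay.Tuple([relay.nn.relu(x), relay.tanh(x), relay.log(x)])
mod = tvm.IRModule.from_expr(relay.Function([data], out))

# With this change, the generated external function returns the abs result
# only once, and the main function reuses that single call result for all
# three consumers (no duplicate fields in the subgraph's output Tuple).
seq = tvm.transform.Sequential([
    transform.AnnotateTarget(target),
    transform.MergeCompilerRegions(),
    transform.PartitionGraph(),
])
print(seq(mod))
```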
@@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
// region_function_calls is a map that maintains
// (each annotated region) --> created function
if (region_function_calls.find(region) != region_function_calls.end()) {
// This section is executed only if there are multiple outputs in the
// region. Thus, the function is always created and at the end there
// would be a tuple node. Therefore, we insert a tuple get item node.
// Use the already created tuple node.
auto sg_call = region_function_calls[region];
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);
auto tuple_get_item_ = TupleGetItem(sg_call, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
} else {
// First time this region is encountered in the traversal;
// create the function.
Array<Expr> fields;
for (auto ret : region->GetOutputs()) {
auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
fields.push_back(ret_expr);
}
int index = GetRetIdx(region, GetRef<Call>(call));
CHECK_NE(index, -1);
Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;
for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
param_expr.push_back(pair.second);
}
}
Function global_region_func;
if (region->GetOutputs().size() == 1) {
// If there is only a single output, no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
}
std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
// relay function-level passes (i.e. simplify inference and fusion)
// from optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;
if (region->GetOutputs().size() == 1) {
// If there is only a single output, no need to add a tuplegetitem
// node
return std::move(ret);
} else {
// Add a tuplegetitem node to select this output out of many
auto tuple_get_item_ = TupleGetItem(ret, index);
tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
return std::move(tuple_get_item_);
}
if (region_function_calls.find(region) == region_function_calls.end()) {
// First time this region is encountered in the traversal;
// create the function.
CreateFunction(region, call);
}
// Retrieve this particular output of the function.
return GetFunctionOutput(region, GetRef<Call>(call));
}
}
@@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
}
/*!
* \brief Get the index of the return (output);
* this is to be used as the tuplegetitem index.
*/
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetOutputs()) {
if (arg == arg_) {
return idx;
}
idx++;
}
return -1;
}
/*!
* \brief This function is called the first time we encounter a compiler_end
* node, to create the function for the subgraph.
*/
void CreateFunction(AnnotatedRegion region, const CallNode* call) {
// Create fields, which is a unique list of outputs. Also populate the
// region_return_indices_ map, which maps the parent of a compiler_end node
// to the corresponding index in fields.
Array<Expr> fields;
int i = 0;
for (auto ret : region->GetOutputs()) {
auto ret_node = Downcast<Call>(ret)->args[0];
// Don't duplicate outputs.
if (!region_return_indices_.count(region) ||
!region_return_indices_[region].count(ret_node)) {
auto ret_expr = VisitExpr(ret_node);
fields.push_back(ret_expr);
region_return_indices_[region][ret_node] = i;
i++;
}
}
Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;
for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
} else {
param_expr.push_back(pair.second);
}
}
Function global_region_func;
if (fields.size() == 1) {
// If there is only a single output, no need to add a tuple
global_region_func =
Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
} else {
auto tuple = Tuple(fields);
global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
}
std::string target = call->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());
global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
runtime::String(name));
global_region_func =
WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
tvm::runtime::String(target));
global_region_func =
WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
}
std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname))
<< "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
// This way we lift the functions that should be handled by external
// codegen to the module scope and rely on the pass manager to prevent
// relay function-level passes (i.e. simplify inference and fusion)
// from optimizing it.
GlobalVar glob_func(fname);
module_->Add(glob_func, global_region_func);
// The return type of the call node is the same as the type of the
// compiler_end node.
auto ret = Call(glob_func, param_expr);
region_function_calls[region] = ret;
}
/*!
* \brief Get the return (output) of the function for the compiler_end node "end_arg".
* This will return either a Call (for a function with a single output) or a
* TupleGetItem (for a function with multiple outputs).
*/
Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
Expr arg = Downcast<Call>(end_arg)->args[0];
// Function has one output.
if (region_return_indices_[region].size() == 1) {
return region_function_calls[region];
}
// Function has multiple outputs.
// Use already made TupleGetItem.
if (region_return_tuplegetitem_.count(region) &&
region_return_tuplegetitem_[region].count(arg)) {
return region_return_tuplegetitem_[region][arg];
}
// Create new TupleGetItem.
CHECK(region_return_indices_.count(region) &&
region_return_indices_[region].count(arg));
int index = region_return_indices_[region][arg];
auto func_call = region_function_calls[region];
auto tuple_get_item_ = TupleGetItem(func_call, index);
tuple_get_item_->checked_type_ = arg->checked_type_;
region_return_tuplegetitem_[region][arg] = tuple_get_item_;
return std::move(tuple_get_item_);
}
/*!
@@ -486,6 +493,23 @@ class Partitioner : public ExprMutator {
region_args;
/*!
* \brief This map maintains the index of an output in the subgraph function
* for a given region. If there are multiple entries for a region, then the
* function has a tuple of multiple outputs for its return.
*/
using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
region_return_indices_;
/*!
* \brief This map holds already created TupleGetItem nodes for accessing
* outputs of a function.
*/
using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
region_return_tuplegetitem_;
/*!
* \brief Each region set is associated with a function in the module.
* This map maintains the mapping between region sets and the function each
* belongs to
......
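The bookkeeping added above reduces to a simple pattern: the first time a region output is seen, it gets a slot in the function's return Tuple and an index in `region_return_indices_`; later compiler_end nodes wrapping the same output reuse that index instead of appending a duplicate field. A minimal sketch of that pattern in plain Python (not TVM API; names are illustrative):

```python
def build_output_fields(region_outputs):
    """Assign each distinct output one slot in the function's return Tuple.

    `region_outputs` is the list of expressions wrapped by compiler_end nodes;
    the same expression may appear several times when it has multiple
    consumers outside the region.
    """
    fields = []        # unique outputs, in first-seen order
    return_index = {}  # output expression -> index into `fields`
    for out in region_outputs:
        if out not in return_index:
            return_index[out] = len(fields)
            fields.append(out)
    return fields, return_index

# Example: one producer ("abs") consumed three times outside the region
# yields a single field rather than three duplicates.
fields, index = build_output_fields(["abs", "abs", "abs"])
assert fields == ["abs"] and index == {"abs": 0}
```

The same reuse idea is applied to the access nodes themselves via `region_return_tuplegetitem_`, so each output is fetched through a single shared TupleGetItem rather than one per consumer.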
@@ -23,6 +23,7 @@ import pytest
import tvm
import tvm.relay.testing
import tvm.relay.op as reg
from tvm import relay
from tvm import runtime
from tvm.relay import transform
@@ -1036,6 +1037,138 @@ def test_multiple_use_of_an_output():
test_same_output_region()
test_different_output_region()
def test_duplicate_outputs():
target = "test_duplicate_outputs"
@reg.register("abs", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
return True
def create_graph():
data = relay.var('data', shape=(10, 10))
x = relay.abs(data)
out_1 = relay.nn.relu(x)
out_2 = relay.tanh(x)
out_3 = relay.log(x)
out = relay.Tuple([out_1, out_2, out_3])
func = relay.Function([data], out)
return func
def expected():
mod = tvm.IRModule()
# function 0
f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
f0_o0 = relay.abs(f0_i0)
func0 = relay.Function([f0_i0], f0_o0)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", target)
func0 = func0.with_attr("global_symbol", target+"_0")
gv0 = relay.GlobalVar(target+"_0")
mod[gv0] = func0
# body
data = relay.var('data', shape=(10, 10))
function_out = gv0(data)
out_1 = relay.nn.relu(function_out)
out_2 = relay.tanh(function_out)
out_3 = relay.log(function_out)
out = relay.Tuple([out_1, out_2, out_3])
func = relay.Function([data], out)
mod["main"] = func
return mod
mod = tvm.IRModule()
mod["main"] = create_graph()
seq = tvm.transform.Sequential([
transform.AnnotateTarget(target),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
])
ref_mod = expected()
partitioned = seq(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_duplicate_merge_and_tuplegetitem():
target = "test_duplicate_merge_and_tuplegetitem"
@reg.register("nn.batch_norm", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
return True
@reg.register("nn.relu", "target." + target)
def abs(attrs, args): # pylint: disable=unused-variable
return True
def create_graph():
data = relay.var('data', shape=(10, 10))
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
bn_mvar = relay.var("bn_var")
x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
out_1 = relay.nn.relu(x[0])
bn_out_1 = x[1]
out_2 = relay.tanh(bn_out_1)
out_3 = relay.log(bn_out_1)
out = relay.Tuple([out_1, out_2, out_3])
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
return func
def expected():
mod = tvm.IRModule()
# function 0
f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
f0_i1 = relay.var(target+"_1_i1")
f0_i2 = relay.var(target+"_1_i2")
f0_i3 = relay.var(target+"_1_i3")
f0_i4 = relay.var(target+"_1_i4")
f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
f0_n1 = f0_n0[1]
f0_n2 = relay.nn.relu(f0_n0[0])
f0_o0 = relay.Tuple([f0_n1, f0_n2])
func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", target)
func0 = func0.with_attr("global_symbol", target+"_1")
gv0 = relay.GlobalVar(target+"_1")
mod[gv0] = func0
# body
data = relay.var('data', shape=(10, 10))
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
bn_mvar = relay.var("bn_var")
function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
get_out0 = relay.TupleGetItem(function_out, 0)
get_out1 = relay.TupleGetItem(function_out, 1)
out_2 = relay.tanh(get_out0)
out_3 = relay.log(get_out0)
out = relay.Tuple([get_out1, out_2, out_3])
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
mod["main"] = func
return mod
mod = tvm.IRModule()
mod["main"] = create_graph()
seq = tvm.transform.Sequential([
transform.AnnotateTarget(target),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
])
ref_mod = expected()
partitioned = seq(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
if __name__ == "__main__":
test_multi_node_compiler()
@@ -1051,3 +1184,5 @@ if __name__ == "__main__":
test_mixed_single_multiple_outputs()
test_dnnl_fuse()
test_multiple_use_of_an_output()
test_duplicate_outputs()
test_duplicate_merge_and_tuplegetitem()