Unverified commit 09eb5082 by Trevor Morris, committed by GitHub

[BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)

* Fix duplicate output in partitiongraph

* Add test case

* Fix test_annotated_regions with duplicate compiler_end outputs

* Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.

* Prevent duplicate outputs in Tuple in PartitionGraph

* Fix lint

* Add another test case for when regions are merged, and when TupleGetItem was duplicated

* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput

* Use std::move for GetFunctionOutput. Fix typo with testcase name

* Use tvm.transform.Sequential
parent e8138f7d
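
For context, the bug this commit fixes appears when one output of a partitioned region feeds several consumers outside the region: the subgraph function's return Tuple previously got one field per compiler_end node, so the same value was duplicated in the output. Below is a minimal sketch of the scenario, adapted from the test case added in this commit; the target name "demo_target" is illustrative only, not part of the change.

import tvm
import tvm.relay.op as reg
from tvm import relay
from tvm.relay import transform

target = "demo_target"  # hypothetical external codegen name

@reg.register("abs", "target." + target)
def _abs(attrs, args):  # mark abs as offloadable to the external target
    return True

data = relay.var("data", shape=(10, 10))
x = relay.abs(data)  # one region output ...
out = relay.Tuple([relay.nn.relu(x), relay.tanh(x), relay.log(x)])  # ... consumed three times
mod = tvm.IRModule.from_expr(relay.Function([data], out))

seq = tvm.transform.Sequential([
    transform.AnnotateTarget(target),
    transform.MergeCompilerRegions(),
    transform.PartitionGraph(),
])
# With this fix, the lifted external function returns abs(data) once
# instead of a 3-tuple repeating the same value.
print(seq(mod))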
@@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
       // region_function_calls is map that maintains
       // (each annotated regions) --> created function
-      if (region_function_calls.find(region) != region_function_calls.end()) {
-        // This section is executed only if there are multiple outputs in the
-        // region Thus, the function is always created and at the end there
-        // would be a tuple node Therefore, we insert a tuple get item node.
-
-        // Use the already created tuple node
-        auto sg_call = region_function_calls[region];
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        auto tuple_get_item_ = TupleGetItem(sg_call, index);
-        tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-        return std::move(tuple_get_item_);
-      } else {
-        // First time this region is encountered in the traversal
-        // Creating the function
-
-        Array<Expr> fields;
-        for (auto ret : region->GetOutputs()) {
-          auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
-          fields.push_back(ret_expr);
-        }
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        Array<Var> params;
-        Array<Expr> param_expr;
-        std::unordered_map<std::string, runtime::NDArray> params_bind;
-
-        for (auto pair : region_args[region]) {
-          params.push_back(pair.first);
-          if (const auto* cn = pair.second.as<ConstantNode>()) {
-            params_bind[pair.first->name_hint()] = cn->data;
-          } else {
-            param_expr.push_back(pair.second);
-          }
-        }
-
-        Function global_region_func;
-        if (region->GetOutputs().size() == 1) {
-          // If there are only a single output; no need to add a tuple
-          global_region_func =
-              Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
-        } else {
-          auto tuple = Tuple(fields);
-          global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
-        }
-
-        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-        std::string name = target + "_" + std::to_string(region->GetID());
-
-        global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
-                                      runtime::String(name));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
-        global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                      tvm::runtime::String(target));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
-
-        // Constant propagation
-        if (!params_bind.empty()) {
-          global_region_func = backend::BindParamsByName(global_region_func, params_bind);
-        }
-
-        std::string fname = name;
-        CHECK(!module_->ContainGlobalVar(fname))
-            << "Global function " << fname << " already exists";
-        // Create a global function and add it to the IRModule for the region.
-        // This way we lift the functions that should be handled by external
-        // codegen to the module scope and rely on the pass manager to prevent
-        // relay function level passes (i.e. simplify inference and fusion)
-        // optimizing it.
-        GlobalVar glob_func(fname);
-        module_->Add(glob_func, global_region_func);
-
-        // The return type of callnode is the same as the type of the
-        // compiler_end node.
-        auto ret = Call(glob_func, param_expr);
-        region_function_calls[region] = ret;
-
-        if (region->GetOutputs().size() == 1) {
-          // If there is only a single output; no need to add a tuplegetitem
-          // node
-          return std::move(ret);
-        } else {
-          // Add a tuplegetitem node to select this output out of many
-          auto tuple_get_item_ = TupleGetItem(ret, index);
-          tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-          return std::move(tuple_get_item_);
-        }
-      }
+      if (region_function_calls.find(region) == region_function_calls.end()) {
+        // First time this region is encountered in the traversal.
+        // Creating the function.
+        CreateFunction(region, call);
+      }
+      // Retrieve this particular output of function.
+      return GetFunctionOutput(region, GetRef<Call>(call));
     }
   }
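
The restructured flow above is easier to see in pseudocode: the first compiler_end encountered for a region builds the lifted function exactly once, and every compiler_end (including the first) then resolves its result through a single lookup path instead of duplicating logic across the two branches. A hypothetical, self-contained Python analogue, with plain dicts and tuples standing in for the C++ maps and Relay nodes:

region_function_calls = {}  # region id -> the single call to the lifted function

def create_function(region_id, outputs):
    # Stand-in for CreateFunction: deduplicate outputs, record the one call.
    unique = tuple(dict.fromkeys(outputs))
    region_function_calls[region_id] = ("call", region_id, unique)

def get_function_output(region_id, output):
    # Stand-in for GetFunctionOutput: the call itself for a single output,
    # otherwise a TupleGetItem selecting this output's slot.
    call = region_function_calls[region_id]
    fields = call[2]
    if len(fields) == 1:
        return call
    return ("tuple_get_item", call, fields.index(output))

def visit_compiler_end(region_id, outputs, this_output):
    if region_id not in region_function_calls:
        # First time this region is encountered in the traversal.
        create_function(region_id, outputs)
    # Retrieve this particular output of the function.
    return get_function_output(region_id, this_output)

# Three compiler_end nodes all wrapping the same region output:
assert visit_compiler_end(0, ["abs", "abs", "abs"], "abs") == ("call", 0, ("abs",))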
@@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
   }

   /*!
-   * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
+   * \brief This function is called the first time we encounter a compiler_end
+   * node, to create the function for the subgraph.
    */
-  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
-    int idx = 0;
-    for (auto arg_ : sg->GetOutputs()) {
-      if (arg == arg_) {
-        return idx;
-      }
-      idx++;
-    }
-    return -1;
+  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
+    // Create fields, a unique list of outputs. Also populate the
+    // region_return_indices_ map, which maps the parent of a compiler_end
+    // node to the corresponding index in fields.
+    Array<Expr> fields;
+    int i = 0;
+    for (auto ret : region->GetOutputs()) {
+      auto ret_node = Downcast<Call>(ret)->args[0];
+      // Don't duplicate outputs.
+      if (!region_return_indices_.count(region) ||
+          !region_return_indices_[region].count(ret_node)) {
+        auto ret_expr = VisitExpr(ret_node);
+        fields.push_back(ret_expr);
+        region_return_indices_[region][ret_node] = i;
+        i++;
+      }
+    }
+
+    Array<Var> params;
+    Array<Expr> param_expr;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+    for (auto pair : region_args[region]) {
+      params.push_back(pair.first);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        param_expr.push_back(pair.second);
+      }
+    }
+
+    Function global_region_func;
+    if (fields.size() == 1) {
+      // If there is only a single output, no need to add a tuple.
+      global_region_func =
+          Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+    } else {
+      auto tuple = Tuple(fields);
+      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+    }
+
+    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+    std::string name = target + "_" + std::to_string(region->GetID());
+
+    global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
+                                  runtime::String(name));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                  tvm::runtime::String(target));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+    }
+
+    std::string fname = name;
+    CHECK(!module_->ContainGlobalVar(fname))
+        << "Global function " << fname << " already exists";
+    // Create a global function and add it to the IRModule for the region.
+    // This way we lift the functions that should be handled by external
+    // codegen to the module scope and rely on the pass manager to prevent
+    // relay function level passes (i.e. simplify inference and fusion)
+    // optimizing it.
+    GlobalVar glob_func(fname);
+    module_->Add(glob_func, global_region_func);
+
+    // The return type of the call node is the same as the type of the
+    // compiler_end node.
+    auto ret = Call(glob_func, param_expr);
+    region_function_calls[region] = ret;
+  }
+
+  /*!
+   * \brief Get the return (output) of the function for the compiler_end node
+   * "end_arg". This will return either a Call (for a function with a single
+   * output) or a TupleGetItem (for a function with multiple outputs).
+   */
+  Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
+    Expr arg = Downcast<Call>(end_arg)->args[0];
+    // Function has one output.
+    if (region_return_indices_[region].size() == 1) {
+      return region_function_calls[region];
+    }
+    // Function has multiple outputs.
+    // Use an already created TupleGetItem if one exists.
+    if (region_return_tuplegetitem_.count(region) &&
+        region_return_tuplegetitem_[region].count(arg)) {
+      return region_return_tuplegetitem_[region][arg];
+    }
+    // Create a new TupleGetItem.
+    CHECK(region_return_indices_.count(region) &&
+          region_return_indices_[region].count(arg));
+    int index = region_return_indices_[region][arg];
+
+    auto func_call = region_function_calls[region];
+    auto tuple_get_item_ = TupleGetItem(func_call, index);
+    tuple_get_item_->checked_type_ = arg->checked_type_;
+    region_return_tuplegetitem_[region][arg] = tuple_get_item_;
+    return std::move(tuple_get_item_);
   }

   /*!
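
The key bookkeeping in CreateFunction is the deduplication loop: each distinct output expression gets exactly one slot in the returned tuple, and region_return_indices_ records which slot each compiler_end argument resolves to. A hypothetical Python rendering of that loop, using strings in place of Relay expressions:

def dedup_outputs(outputs):
    # Mirrors the fields/region_return_indices_ loop: one tuple slot per
    # distinct output expression, plus a map from expression to slot index.
    fields, indices = [], {}
    for ret in outputs:
        if ret not in indices:  # don't duplicate outputs
            indices[ret] = len(fields)
            fields.append(ret)
    return fields, indices

fields, indices = dedup_outputs(["abs", "abs", "relu"])
assert fields == ["abs", "relu"]
assert indices == {"abs": 0, "relu": 1}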
@@ -486,6 +493,23 @@ class Partitioner : public ExprMutator {
       region_args;

   /*!
+   * \brief This map maintains the index of an output in the subgraph function
+   * for a given region. If there are multiple entries for a region, then the
+   * function has a tuple of multiple outputs for its return.
+   */
+  using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
+      region_return_indices_;
+
+  /*!
+   * \brief This map holds already created TupleGetItem nodes for accessing
+   * outputs of a function.
+   */
+  using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
+      region_return_tuplegetitem_;
+
+  /*!
    * \brief Each region set is associated with a function in the module.
    * This map maintains the mapping between regionsets and the function it
    * belongs to
...
@@ -23,6 +23,7 @@ import pytest
 import tvm
 import tvm.relay.testing
+import tvm.relay.op as reg
 from tvm import relay
 from tvm import runtime
 from tvm.relay import transform
@@ -1036,6 +1037,138 @@ def test_multiple_use_of_an_output():
     test_same_output_region()
     test_different_output_region()

+def test_duplicate_outputs():
+    target = "test_duplicate_outputs"
+
+    @reg.register("abs", "target." + target)
+    def abs(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        x = relay.abs(data)
+        out_1 = relay.nn.relu(x)
+        out_2 = relay.tanh(x)
+        out_3 = relay.log(x)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
+        f0_o0 = relay.abs(f0_i0)
+        func0 = relay.Function([f0_i0], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_0")
+        gv0 = relay.GlobalVar(target+"_0")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        function_out = gv0(data)
+        out_1 = relay.nn.relu(function_out)
+        out_2 = relay.tanh(function_out)
+        out_3 = relay.log(function_out)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = tvm.transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
+def test_duplicate_merge_and_tuplegetitem():
+    target = "test_duplicate_merge_and_tuplegetitem"
+
+    @reg.register("nn.batch_norm", "target." + target)
+    def batch_norm(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    @reg.register("nn.relu", "target." + target)
+    def relu(attrs, args):  # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        out_1 = relay.nn.relu(x[0])
+        bn_out_1 = x[1]
+        out_2 = relay.tanh(bn_out_1)
+        out_3 = relay.log(bn_out_1)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
+        f0_i1 = relay.var(target+"_1_i1")
+        f0_i2 = relay.var(target+"_1_i2")
+        f0_i3 = relay.var(target+"_1_i3")
+        f0_i4 = relay.var(target+"_1_i4")
+        f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
+        f0_n1 = f0_n0[1]
+        f0_n2 = relay.nn.relu(f0_n0[0])
+        f0_o0 = relay.Tuple([f0_n1, f0_n2])
+        func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_1")
+        gv0 = relay.GlobalVar(target+"_1")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        get_out0 = relay.TupleGetItem(function_out, 0)
+        get_out1 = relay.TupleGetItem(function_out, 1)
+        out_2 = relay.tanh(get_out0)
+        out_3 = relay.log(get_out0)
+        out = relay.Tuple([get_out1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = tvm.transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
 if __name__ == "__main__":
     test_multi_node_compiler()
@@ -1051,3 +1184,5 @@ if __name__ == "__main__":
     test_mixed_single_multiple_outputs()
     test_dnnl_fuse()
     test_multiple_use_of_an_output()
+    test_duplicate_outputs()
+    test_duplicate_merge_and_tuplegetitem()