Unverified Commit 046b0d98 by mbaret Committed by GitHub

[BYOC] Bind constant tuples in graph partitioner (#5476)

* Bind constant tuples in the graph partitioner

Change-Id: I815b32b5445a536c1837369b04f67dbbb0aed900

* Add partitioning test

Change-Id: I3a492ec8d1beab4830214e3bc8da2a7c80771ca4

* Rename test target

Change-Id: Ie32f37c1395ff597c0047ad3a93ed04ce3f3125d
parent 4d8816cc
......@@ -393,12 +393,26 @@ class Partitioner : public ExprMutator {
Array<Var> params;
Array<Expr> param_expr;
std::unordered_map<std::string, runtime::NDArray> params_bind;
Map<Var, Expr> params_bind;
auto IsConstant = [](const Expr& expr) {
if (expr->IsInstance<ConstantNode>())
return true;
if (expr->IsInstance<TupleNode>()) {
auto tuple = expr.as<TupleNode>();
for (const auto& field : tuple->fields) {
if (!field->IsInstance<ConstantNode>())
return false;
}
return true;
}
return false;
};
for (auto pair : region_args[region]) {
params.push_back(pair.first);
if (const auto* cn = pair.second.as<ConstantNode>()) {
params_bind[pair.first->name_hint()] = cn->data;
if (IsConstant(pair.second)) {
params_bind.Set(pair.first, pair.second);
} else {
param_expr.push_back(pair.second);
}
......@@ -428,7 +442,7 @@ class Partitioner : public ExprMutator {
// Constant propagation
if (!params_bind.empty()) {
global_region_func = backend::BindParamsByName(global_region_func, params_bind);
global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
}
std::string fname = name;
......
......@@ -1155,6 +1155,42 @@ def test_duplicate_merge_and_tuplegetitem():
partitioned = seq(mod)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
def test_constant_tuples():
@reg.register("qnn.concatenate", "target.const_tuples")
def add(attrs, args): # pylint: disable=unused-variable
return True
def create_graph():
a = relay.var('a', shape=(10, 10), dtype="uint8")
b = relay.var('b', shape=(10, 10), dtype="uint8")
a1 = relay.abs(a)
zeroi = relay.const(1, "int32")
zerof = relay.const(0, "float32")
con = relay.qnn.op.concatenate((a1, b),
input_scales=(zerof, zerof),
input_zero_points=(zeroi, zeroi),
output_scale=zerof,
output_zero_point=zeroi,
axis=1)
f = relay.Function([a, b], con)
mod = tvm.IRModule.from_expr(f)
return mod
seq = tvm.transform.Sequential([
transform.AnnotateTarget("const_tuples"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
])
partitioned = seq(create_graph())
concat = partitioned["const_tuples_0"].body
assert type(concat.args[1]) == relay.Tuple
assert type(concat.args[2]) == relay.Tuple
assert type(concat.args[3]) == relay.Constant
assert type(concat.args[4]) == relay.Constant
if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
......@@ -1171,3 +1207,4 @@ if __name__ == "__main__":
test_multiple_use_of_an_output()
test_duplicate_outputs()
test_duplicate_merge_and_tuplegetitem()
test_constant_tuples()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment