Unverified commit 98b17590 by Zhi, committed by GitHub

[Relay] Target annotation for external codegen (#4933)

* op based external compiler annotation

* Use TVM register directly

* Small fix

* test graph

Co-authored-by: Cody Yu <comaniac0422@gmail.com>
parent 09ddc3eb
......@@ -180,6 +180,20 @@ using FTVMLegalize = runtime::TypedPackedFunc<
const Array<tvm::relay::Type>& arg_types)>;
/*!
* \brief Annotates an expression to indicate if an op should be compiled using
* the given compiler/target.
*
* \param attrs The attribute of the original expr.
* \param args The arguments of the original expr.
*
* \return true if this op should be registered to invoke a specific compiler
* for codegen, otherwise, false.
*/
using FTVMAnnotateTarget = runtime::TypedPackedFunc<
bool(const Attrs& attrs, // NOLINT(*)
const Array<Expr>& args)>;
/*!
* \brief Forward rewriting rule for a specific op.
*
* \param ref_call The reference old call type to be rewritten.
......
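The new FTVMAnnotateTarget attribute is a per-op predicate registered under the key "target.<compiler>". A minimal sketch of registering such a predicate from Python, assuming a hypothetical compiler name "my_codegen" (the DNNL registrations later in this patch follow the same pattern):

    from tvm.relay import op as reg

    @reg.register("nn.relu", "target.my_codegen")
    def _relu_on_my_codegen(attrs, args):
        # Unconditionally offload nn.relu to the hypothetical "my_codegen" backend.
        return True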
......@@ -19,7 +19,7 @@
# operator defs
from .op import get, register, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \
Op, OpPattern, OpStrategy, debug
Op, OpPattern, OpStrategy, debug, register_external_compiler
from . import strategy
# Operators
......
......@@ -15,5 +15,5 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Neural network related operators."""
from __future__ import absolute_import as _abs
"""Contrib modules."""
from .dnnl import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, too-many-lines
"""Contrib operations."""
from __future__ import absolute_import as _abs
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""DNNL library supported operators.
There are two ways to registering a function for an op to indicate if it is
supported by DNNL.
- The first and simplest way is to use the helper so that
users only need to provide the operator name and a boolean value to indicate if
it is supported. For example:
.. code-block:: python
add = _register_external_op_helper("add")
add = _register_external_op_helper("add", True)
add = _register_external_op_helper("add", False)
- The other way is to implement the function by themselves to
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
from ... import op as reg
def _register_external_op_helper(op_name, supported=True):
"""The helper function to indicate that a given operator can be supported
by DNNL.
Paramters
---------
op_name : Str
The name of operator that will be registered.
Returns
-------
f : callable
A function that returns if the operator is supported by DNNL.
"""
@reg.register(op_name, "target.dnnl")
def _func_wrapper(attrs, args):
return supported
return _func_wrapper
_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.dense")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add")
_register_external_op_helper("subtract")
_register_external_op_helper("multiply")
@reg.register("nn.batch_norm", "target.dnnl")
def batch_norm(attrs, args):
"""Check if the external DNNL codegen should be used.
FIXME(@zhiics, @comaniac): Turned off because multiple outputs are not yet supported.
"""
return False
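The docstring above also mentions a second way: write the predicate yourself and inspect the call's attributes. A hedged sketch, not part of this patch, using nn.max_pool2d and a layout check purely for illustration:

    @reg.register("nn.max_pool2d", "target.dnnl")
    def _max_pool2d(attrs, args):
        # Offload only pooling ops in NCHW layout; any other layout falls back to TVM.
        return attrs.layout == "NCHW"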
......@@ -453,14 +453,36 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
return register(op_name, "FShapeFunc", shape_func, level)
def register_external_compiler(op_name, fexternal=None, level=10):
"""Register the external compiler for an op.
Parameters
----------
op_name : str
The name of the operator.
fexternal : function (attrs: Attrs, args: List[Expr], compiler: str)
-> new_expr: Expr
The function for wrapping a call expr with compiler_begin and
compiler_end.
level : int
The priority level
"""
return register(op_name, "FTVMExternalCompiler", fexternal, level)
@tvm._ffi.register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
return lower(schedule, list(inputs) + list(outputs), name=name)
@tvm._ffi.register_func("relay.op.compiler._build")
def _build(lowered_funcs):
return build(lowered_funcs, target="llvm")
_schedule_injective = None
_schedule_reduce = None
......
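A hedged sketch of how the new register_external_compiler hook could be used; it is not exercised by the tests in this patch, the op and helper name are hypothetical, and the annotation ops are the relay.annotation.compiler_begin/compiler_end helpers this patch relies on elsewhere:

    from tvm import relay
    from tvm.relay.op import register_external_compiler

    def _add_external(attrs, args, compiler):
        # Wrap each argument with compiler_begin and the rebuilt call with compiler_end.
        new_args = [relay.annotation.compiler_begin(arg, compiler) for arg in args]
        new_call = relay.add(*new_args)
        return relay.annotation.compiler_end(new_call, compiler)

    register_external_compiler("add", _add_external)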
......@@ -552,6 +552,25 @@ def PartitionGraph():
return _transform.PartitionGraph()
def AnnotateTarget(target):
"""Annotate ops in an experession with a provied compiler/target and then
use it for codegen.
Parameters
----------
target : String
The target compiler used for codegen.
Returns
-------
ret : tvm.relay.Pass
The annotated pass that wrapps ops with subgraph_start and
subgraph_end.
"""
return _transform.AnnotateTarget(target)
def Inline():
"""Perform inlining on the given Relay IR module. The global functions that
are marked as `inline` should be always inlined. A cost model will be
......
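The intended flow, mirrored by the tests at the end of this patch, is to annotate a module for a registered compiler and then partition it. A minimal sketch using the DNNL target exercised here:

    import tvm
    import tvm.relay.transform as transform

    mod = tvm.IRModule.from_expr(func)           # func is any relay.Function
    mod = transform.AnnotateTarget("dnnl")(mod)  # wrap supported calls with annotations
    mod = transform.PartitionGraph()(mod)        # group annotated regions into external functions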
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/pass/annotate_target.cc
* \brief Wraps a call with compiler_begin and compiler_end to indicate that
* the op of this call node will use external compiler.
*/
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
namespace annotate_target {
// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator {
public:
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
Expr VisitExpr_(const CallNode* cn) {
// TODO(@zhiics, @comaniac) Handle composite functions.
auto new_e = ExprMutator::VisitExpr_(cn);
Call call = Downcast<Call>(new_e);
static auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
bool external = fannotate[op](call->attrs, call->args);
if (external) {
tvm::Array<tvm::relay::Expr> compiler_begins;
for (const auto& it : call->args) {
const auto* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
CHECK(begin_op);
Expr begin = (*begin_op)(it, target_);
compiler_begins.push_back(begin);
}
Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs);
const auto* end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
Expr end = (*end_op)(update_call, target_);
return end;
}
} else {
LOG(WARNING) << op->name << " in " << target_
<< " is not registered. It will be executed on CPU.";
}
return new_e;
}
private:
std::string target_;
};
Expr AnnotateTarget(const Expr& expr, const std::string& target) {
return AnnotateTargetWrapper(target).Mutate(expr);
}
} // namespace annotate_target
namespace transform {
Pass AnnotateTarget(const std::string& target) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(relay::annotate_target::AnnotateTarget(f, target));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc",
{tir::StringImmNode::make("InferType")});
return transform::Sequential({func_pass, InferType()}, "AnnotateTarget");
}
TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget")
.set_body_typed(AnnotateTarget);
} // namespace transform
} // namespace relay
} // namespace tvm
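In Relay terms, AnnotateTargetWrapper rewrites each supported call into an annotated one; a small illustration for a single call (the expected() function in the test below builds the same structure for a two-convolution graph):

    # Before AnnotateTarget("dnnl"):
    #   add(x, y)
    # After, assuming "add" is registered under "target.dnnl":
    #   compiler_end(add(compiler_begin(x, "dnnl"),
    #                    compiler_begin(y, "dnnl")), "dnnl")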
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for annotating external targets."""
import os
import sys
import numpy as np
import pytest
import tvm
import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
from tvm.contrib import util
def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
ctx=tvm.cpu(), params=None):
if sys.platform == "win32":
print("Skip test on Windows for now")
return
def update_lib(lib):
test_dir = os.path.dirname(
os.path.realpath(os.path.expanduser(__file__)))
source_dir = os.path.join(test_dir, "..", "..", "..")
contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
kwargs = {}
kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path]
tmp_path = util.tempdir()
lib_name = 'lib.so'
lib_path = tmp_path.relpath(lib_name)
lib.export_library(lib_path, fcompile=False, **kwargs)
lib = runtime.load_module(lib_path)
return lib
def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
lib = update_lib(lib)
exe = runtime.vm.Executable.load_exec(code, lib)
vm = runtime.vm.VirtualMachine(exe)
vm.init(ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
for name, data in map_inputs.items():
rt_mod.set_input(name, data)
rt_mod.set_input(**param)
rt_mod.run()
out = tvm.nd.empty(out_shape, ctx=ctx)
out = rt_mod.get_output(0, out)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
check_vm_result()
check_graph_runtime_result()
def test_extern_dnnl():
def annotated(dtype, ishape, w1shape):
data = relay.var('data', shape=(ishape), dtype=dtype)
weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)
f = relay.Function([data, weight1], out)
mod = tvm.IRModule.from_expr(f)
return mod
def expected(dtype, ishape, w1shape):
data = relay.var('data', shape=(ishape), dtype=dtype)
weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
begin0 = relay.annotation.compiler_begin(data, "dnnl")
begin1 = relay.annotation.compiler_begin(weight1, "dnnl")
depthwise_conv2d_1 = relay.nn.conv2d(begin0,
begin1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl")
begin2 = relay.annotation.compiler_begin(end0, "dnnl")
begin3 = relay.annotation.compiler_begin(end0, "dnnl")
begin4 = relay.annotation.compiler_begin(weight1, "dnnl")
depthwise_conv2d_2 = relay.nn.conv2d(begin3,
begin4,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl")
begin5 = relay.annotation.compiler_begin(end1, "dnnl")
out = relay.add(begin2, begin5)
end2 = relay.annotation.compiler_end(out, "dnnl")
f = relay.Function([data, weight1], end2)
mod = tvm.IRModule.from_expr(f)
return mod
dtype = "float32"
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
def test_annotate():
mod = annotated(dtype, ishape, w1shape)
mod = transform.AnnotateTarget("dnnl")(mod)
ref_mod = expected(dtype, ishape, w1shape)
assert relay.analysis.alpha_equal(mod, ref_mod)
def test_run():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return
ref_mod = annotated(dtype, ishape, w1shape)
mod = annotated(dtype, ishape, w1shape)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu())
ref_res = ref_ex.evaluate()(i_data, w1_data)
check_result(mod, {"data": i_data, "weight1": w1_data},
(1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
test_annotate()
test_run()
def test_extern_dnnl_mobilenet():
if not tvm.get_global_func("relay.ext.dnnl", True):
print("skip because DNNL codegen is not available")
return
dtype = 'float32'
ishape = (1, 3, 224, 224)
mod, params = relay.testing.mobilenet.get_workload(
batch_size=1, dtype='float32')
mod = transform.AnnotateTarget("dnnl")(mod)
mod = transform.PartitionGraph()(mod)
i_data = np.random.uniform(0, 1, ishape).astype(dtype)
ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1,
dtype='float32')
ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
ref_res = ref_ex.evaluate()(i_data, **params)
check_result(mod, {"data": i_data},
(1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
if __name__ == "__main__":
test_extern_dnnl()
test_extern_dnnl_mobilenet()