Commit 156aa590 by Josh Fromm, committed by Jared Roesch

[Relay][Frontend][ONNX] New Operators and Opsets to Support BERT (#4197)

* Added slice v10

* Added constantofshape operation and small refactor.

* Finished one_hot implementation.

* Reshape working across all bert layers.

* Fixed constantofshape and removed code duplication.

* ONNX model fully ingested.

* Working on improving onnx tests.

* Changed onnx testing to use onnxruntime instead of caffe2, also formatted.

* Add arbitrary output nodes to onnx frontend.

* Added v6 tiling for BERT SQuAD 8 support.

* Small syntax fixes

* Reduced code duplication in split opset versions.

* Added batch matmul test

* Added unstack split testing.

* Added onehot test; probably needs a little cleanup.

* Replaced deprecated constant fill with constantofshape and updated tests accordingly.

* Added tests for new opset version of slice and tile.

* lint clean up

* Lint fixes

* Changed onnx dependency

* Went back to caffe2 runtime for CI integration.

* Rebase and small typo/syntax changes.

* Added hard casting of onehot attributes to int.
parent 71f39be5
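
For context, a minimal sketch of how the extended frontend could be exercised end to end on a BERT-style ONNX model. The model file name, input names, and shapes below are hypothetical placeholders, not taken from this commit.

import onnx
import tvm
from tvm import relay

# Hypothetical model file and input layout; substitute the real graph's names.
model = onnx.load("bert_model.onnx")
shape_dict = {"input_ids": (1, 256), "segment_ids": (1, 256)}

# The importer now covers Slice v10, ConstantOfShape, OneHot, Tile v6, etc.
mod, params = relay.frontend.from_onnx(model, shape=shape_dict)
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target="llvm", params=params)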
@@ -19,11 +19,13 @@ from __future__ import absolute_import as _abs
import logging
import tvm
import numpy as np
from topi.util import get_const_tuple
from .. import expr as _expr
from .. import module as _module
from .. import transform as _transform
from .. import op as _op
from .. import analysis
class RequiredAttr(object):
@@ -474,6 +476,50 @@ def infer_channels(inputs, transpose=False):
    return channels
def infer_value(input_val, params):
    """A hack for getting the value of an expression by evaluating a
    portion of the relay graph. This is often needed for functions
    whose output shape depends on the value of a tensor.
    """
    from tvm.contrib import graph_runtime
    # Check that all free variables have associated parameters.
    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
        input_val)), "All inputs to infer must be available in params."
    func = _expr.Function(analysis.free_vars(input_val), input_val)
    with tvm.relay.build_config(opt_level=0):
        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
    ctx = tvm.cpu(0)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    m.run()
    return m.get_output(0)
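
As a self-contained illustration of the helper (assumed usage, not part of the diff): every free variable in the expression must already be bound in params before calling it.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.frontend.common import infer_value

# Small expression whose concrete value we want at import time.
x = relay.var("x", shape=(3,), dtype="float32")
expr = relay.add(x, relay.const(np.ones(3, dtype="float32")))

# Bind the only free variable so the subgraph can be evaluated.
params = {"x": tvm.nd.array(np.arange(3, dtype="float32"))}
print(infer_value(expr, params).asnumpy())  # [1. 2. 3.]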
def infer_value_simulated(input_val, params):
    """Extension to infer_value that can be used when some input
    values are missing. This function creates dummy inputs with the same
    shape and random values, then calls infer_value. This is helpful when
    implementing certain onnx operators where we need to evaluate the graph
    to determine a static shape.
    """
    fake_params = []
    # Add a fake copy of all missing params.
    for free_param in analysis.free_vars(input_val):
        if free_param.name_hint not in params:
            fp_dtype = free_param.type_annotation.dtype
            fp_shape = [s.value for s in free_param.type_annotation.shape]
            fake_params.append(free_param)
            params[free_param.name_hint] = tvm.nd.array(
                np.random.rand(*fp_shape).astype(fp_dtype)
            )
    # Now infer the value.
    output_value = infer_value(input_val, params)
    # Clean fake params out of param dictionary.
    for fake_p in fake_params:
        params.pop(fake_p.name_hint, None)
    return output_value
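
To motivate this helper, here is a hedged sketch (a simplified stand-in, not code from this commit) of how a converter for a dynamic Reshape could resolve its target shape: the shape subgraph is evaluated to a constant even though the data tensor has no bound value, because a random stand-in is generated for it.

def _convert_reshape(data, shape_expr, params):
    # Evaluate the shape subgraph; any unbound inputs (e.g. the data
    # tensor itself) are faked with random values of matching shape/dtype.
    static_shape = infer_value_simulated(shape_expr, params).asnumpy()
    return _op.reshape(data, newshape=tuple(static_shape.astype("int32")))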
def new_var(name_hint,
            type_annotation=None,
            shape=None,
...
@@ -39,22 +39,10 @@ from .common import AttrCvt, get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_channels as _infer_channels
from .common import infer_value as _infer_value
__all__ = ['from_tensorflow']
def _infer_value(input_val, params):
    from tvm.contrib import graph_runtime
    # Check that all free variables have associated parameters.
    assert all(var.name_hint in params.keys() for var in analysis.free_vars(
        input_val)), "All inputs to infer must be available in params."
    func = _expr.Function(analysis.free_vars(input_val), input_val)
    with tvm.relay.build_config(opt_level=0):
        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
    ctx = tvm.context("llvm", 0)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    m.run()
    return m.get_output(0)
def _get_pad_pair(input1d, kernel1d, stride1d):
    if input1d % stride1d == 0:
...