Commit df7cc5db by lixiaoquan Committed by Zhi

[TENSORFLOW] Convert scalar Const into tvm.relay.const (#3885)

* [TENSORFLOW] Convert scalar Const into tvm.relay.const

* use _get_num_param() and _get_list_param()
parent 5ed251a6
......@@ -84,10 +84,12 @@ def _dimension_constraint():
return _dim_check, "Only 2d kernel supported."
def _get_param(params, input_node):
if isinstance(input_node, _expr.Constant):
return np.atleast_1d(input_node.data.asnumpy())
return params.pop(input_node.name_hint).asnumpy()
def _get_num_param(params, input_node):
return _get_param(params, input_node)[0]
return _get_param(params, input_node).item()
def _get_list_param(params, input_node):
return _get_param(params, input_node).tolist()
......@@ -335,9 +337,9 @@ def _crop_and_resize():
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try:
boxes = params.pop(inputs[1].name_hint).asnumpy().tolist()
box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist()
crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist()
boxes = _get_list_param(params, inputs[1])
box_ind = _get_list_param(params, inputs[2])
crop_size = _get_list_param(params, inputs[3])
except (IndexError, KeyError):
boxes = _infer_value(inputs[1], params).asnumpy().tolist()
box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
......@@ -505,7 +507,7 @@ def _pack():
def _tile():
def _impl(inputs, attr, params):
reps = params[inputs.pop().name_hint].asnumpy()
reps = _get_list_param(params, inputs.pop())
new_input = []
new_input.append(inputs.pop(0))
......@@ -752,7 +754,7 @@ def _sum():
def _reduce(op):
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()
axis = _get_list_param(params, inputs[1])
axis = tuple(axis)
return AttrCvt(
op_name=op,
......@@ -937,8 +939,8 @@ def _where():
def _clip_by_value():
def _impl(inputs, attr, params):
a_min = params.pop(inputs[1].name_hint).asnumpy()[0]
a_max = params.pop(inputs[2].name_hint).asnumpy()[0]
a_min = _get_num_param(params, inputs[1])
a_max = _get_num_param(params, inputs[2])
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
return _impl
......@@ -965,10 +967,11 @@ def _rank():
def _range():
def _impl(inputs, attr, params):
start = params.pop(inputs[0].name_hint).asnumpy()[0]
limit = params.pop(inputs[1].name_hint).asnumpy()[0] \
if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0]
delta = params.pop(inputs[2].name_hint).asnumpy()[0]
start = _get_param(params, inputs[0])[0]
limit = _get_param(params, inputs[1])[0] \
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
else params.pop('Rank').asnumpy()[0]
delta = _get_param(params, inputs[2])[0]
dtype = attr['dtype'].name if 'dtype' in attr else "int32"
return AttrCvt(
op_name="arange",
......@@ -1084,7 +1087,7 @@ def _softplus():
def _topk():
def _impl(inputs, attr, params):
k = int(params.pop(inputs.pop(1).name_hint).asnumpy())
k = int(_get_num_param(params, inputs.pop(1)))
if k < 1:
raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2')
......@@ -1196,7 +1199,7 @@ def _batch_to_space_nd():
def _prod():
def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
axis = _get_num_param(params, inputs[1])
keepdims = attr['keep_dims']
return _op.prod(inputs[0], int(axis), keepdims=keepdims)
return _impl
......@@ -2104,13 +2107,12 @@ class GraphProto(object):
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
new_array[0] = np_array
self._params[name] = tvm.nd.array(new_array)
self._nodes[name] = [tvm.relay.const(new_array)]
else:
self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = [_expr.var(name,
shape=self._params[name].shape,
dtype=self._params[name].dtype)]
self._nodes[name] = [_expr.var(name,
shape=self._params[name].shape,
dtype=self._params[name].dtype)]
else:
if key not in ('dtype', '_output_shapes', '_class'):
raise NotImplementedError \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment