Commit df7cc5db by lixiaoquan Committed by Zhi

[TENSORFLOW] Convert scalar Const into tvm.relay.const (#3885)

* [TENSORFLOW] Convert scalar Const into tvm.relay.const

* use _get_num_param() and _get_list_param()
parent 5ed251a6
...@@ -84,10 +84,12 @@ def _dimension_constraint(): ...@@ -84,10 +84,12 @@ def _dimension_constraint():
return _dim_check, "Only 2d kernel supported." return _dim_check, "Only 2d kernel supported."
def _get_param(params, input_node): def _get_param(params, input_node):
if isinstance(input_node, _expr.Constant):
return np.atleast_1d(input_node.data.asnumpy())
return params.pop(input_node.name_hint).asnumpy() return params.pop(input_node.name_hint).asnumpy()
def _get_num_param(params, input_node): def _get_num_param(params, input_node):
return _get_param(params, input_node)[0] return _get_param(params, input_node).item()
def _get_list_param(params, input_node): def _get_list_param(params, input_node):
return _get_param(params, input_node).tolist() return _get_param(params, input_node).tolist()
...@@ -335,9 +337,9 @@ def _crop_and_resize(): ...@@ -335,9 +337,9 @@ def _crop_and_resize():
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try: try:
boxes = params.pop(inputs[1].name_hint).asnumpy().tolist() boxes = _get_list_param(params, inputs[1])
box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist() box_ind = _get_list_param(params, inputs[2])
crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist() crop_size = _get_list_param(params, inputs[3])
except (IndexError, KeyError): except (IndexError, KeyError):
boxes = _infer_value(inputs[1], params).asnumpy().tolist() boxes = _infer_value(inputs[1], params).asnumpy().tolist()
box_ind = _infer_value(inputs[2], params).asnumpy().tolist() box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
...@@ -505,7 +507,7 @@ def _pack(): ...@@ -505,7 +507,7 @@ def _pack():
def _tile(): def _tile():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
reps = params[inputs.pop().name_hint].asnumpy() reps = _get_list_param(params, inputs.pop())
new_input = [] new_input = []
new_input.append(inputs.pop(0)) new_input.append(inputs.pop(0))
...@@ -752,7 +754,7 @@ def _sum(): ...@@ -752,7 +754,7 @@ def _sum():
def _reduce(op): def _reduce(op):
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy() axis = _get_list_param(params, inputs[1])
axis = tuple(axis) axis = tuple(axis)
return AttrCvt( return AttrCvt(
op_name=op, op_name=op,
...@@ -937,8 +939,8 @@ def _where(): ...@@ -937,8 +939,8 @@ def _where():
def _clip_by_value(): def _clip_by_value():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
a_min = params.pop(inputs[1].name_hint).asnumpy()[0] a_min = _get_num_param(params, inputs[1])
a_max = params.pop(inputs[2].name_hint).asnumpy()[0] a_max = _get_num_param(params, inputs[2])
return _op.clip(inputs[0], a_min=a_min, a_max=a_max) return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
return _impl return _impl
...@@ -965,10 +967,11 @@ def _rank(): ...@@ -965,10 +967,11 @@ def _rank():
def _range(): def _range():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
start = params.pop(inputs[0].name_hint).asnumpy()[0] start = _get_param(params, inputs[0])[0]
limit = params.pop(inputs[1].name_hint).asnumpy()[0] \ limit = _get_param(params, inputs[1])[0] \
if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0] if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
delta = params.pop(inputs[2].name_hint).asnumpy()[0] else params.pop('Rank').asnumpy()[0]
delta = _get_param(params, inputs[2])[0]
dtype = attr['dtype'].name if 'dtype' in attr else "int32" dtype = attr['dtype'].name if 'dtype' in attr else "int32"
return AttrCvt( return AttrCvt(
op_name="arange", op_name="arange",
...@@ -1084,7 +1087,7 @@ def _softplus(): ...@@ -1084,7 +1087,7 @@ def _softplus():
def _topk(): def _topk():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
k = int(params.pop(inputs.pop(1).name_hint).asnumpy()) k = int(_get_num_param(params, inputs.pop(1)))
if k < 1: if k < 1:
raise tvm.error.OpAttributeInvalid( raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2') 'Attribute k must be positive in operator TopKV2')
...@@ -1196,7 +1199,7 @@ def _batch_to_space_nd(): ...@@ -1196,7 +1199,7 @@ def _batch_to_space_nd():
def _prod(): def _prod():
def _impl(inputs, attr, params): def _impl(inputs, attr, params):
axis = params.pop(inputs[1].name_hint).asnumpy()[0] axis = _get_num_param(params, inputs[1])
keepdims = attr['keep_dims'] keepdims = attr['keep_dims']
return _op.prod(inputs[0], int(axis), keepdims=keepdims) return _op.prod(inputs[0], int(axis), keepdims=keepdims)
return _impl return _impl
...@@ -2104,13 +2107,12 @@ class GraphProto(object): ...@@ -2104,13 +2107,12 @@ class GraphProto(object):
if array_ndim == 0: if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype) new_array = np.empty([1], dtype=np_array.dtype)
new_array[0] = np_array new_array[0] = np_array
self._params[name] = tvm.nd.array(new_array) self._nodes[name] = [tvm.relay.const(new_array)]
else: else:
self._params[name] = tvm.nd.array(np_array) self._params[name] = tvm.nd.array(np_array)
self._nodes[name] = [_expr.var(name,
self._nodes[name] = [_expr.var(name, shape=self._params[name].shape,
shape=self._params[name].shape, dtype=self._params[name].dtype)]
dtype=self._params[name].dtype)]
else: else:
if key not in ('dtype', '_output_shapes', '_class'): if key not in ('dtype', '_output_shapes', '_class'):
raise NotImplementedError \ raise NotImplementedError \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment