Commit 4b827593 by 雾雨魔理沙 Committed by Wuwei Lin

[Relay] [Parser] fix parser for cast. (#3873)

* fix

* lint
parent 60de5be1
...@@ -77,7 +77,8 @@ class ExprOp(OpWrapper): ...@@ -77,7 +77,8 @@ class ExprOp(OpWrapper):
try: try:
return expr.Call(self.operator, args, attrs, type_args) return expr.Call(self.operator, args, attrs, type_args)
except Exception: except Exception:
raise Exception(str(self.operator) + " " + str(attrs)) raise Exception("Operator {} is not registered. It's attributes are {}"
.format(self.operator, attrs))
class FuncOp(OpWrapper): class FuncOp(OpWrapper):
"""Convert the attrs, call the python function with the attrs passed in as keyword arguments. """Convert the attrs, call the python function with the attrs passed in as keyword arguments.
...@@ -132,6 +133,7 @@ FUNC_OPS = { ...@@ -132,6 +133,7 @@ FUNC_OPS = {
"nn.dropout": op.nn.dropout_raw, "nn.dropout": op.nn.dropout_raw,
"zeros": op.zeros, "zeros": op.zeros,
"split": op.split, "split": op.split,
"cast": op.cast
} }
TYPE_PREFIXES = [ TYPE_PREFIXES = [
......
...@@ -169,19 +169,23 @@ def test_inception_v3(): ...@@ -169,19 +169,23 @@ def test_inception_v3():
net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1) net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
astext(net) astext(net)
def test_squeezenet(): def test_squeezenet():
for version in ['1.0', '1.1']: for version in ['1.0', '1.1']:
net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
astext(net) astext(net)
def test_vgg(): def test_vgg():
net, params = tvm.relay.testing.vgg.get_workload(batch_size=1) net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
astext(net) astext(net)
def test_densenet(): def test_densenet():
net, params = tvm.relay.testing.densenet.get_workload(batch_size=1) net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
astext(net) astext(net)
def test_call_node_order(): def test_call_node_order():
x = relay.var("x") x = relay.var("x")
y = relay.var("y") y = relay.var("y")
...@@ -196,6 +200,7 @@ def test_call_node_order(): ...@@ -196,6 +200,7 @@ def test_call_node_order():
"};\n" "};\n"
"%2(%1)") "%2(%1)")
def test_let_inlining(): def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)]) tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x") x = relay.var("x")
...@@ -208,10 +213,19 @@ def test_let_inlining(): ...@@ -208,10 +213,19 @@ def test_let_inlining():
("let %x = (0, 0);\n" ("let %x = (0, 0);\n"
"%x") "%x")
def test_zeros(): def test_zeros():
x = relay.op.zeros([], "float32") x = relay.op.zeros([], "float32")
astext(x) astext(x)
def test_cast():
data = relay.var('data', dtype='float32')
fp16_cast = relay.cast(data, dtype='float16')
cast_func = relay.Function(relay.analysis.free_vars(fp16_cast), fp16_cast)
astext(cast_func)
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
test_lstm() test_lstm()
...@@ -233,3 +247,4 @@ if __name__ == "__main__": ...@@ -233,3 +247,4 @@ if __name__ == "__main__":
test_let_if_scope() test_let_if_scope()
test_variable_name() test_variable_name()
test_call_node_order() test_call_node_order()
test_cast()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment