Commit 05a5c170 by Zhi Committed by Tianqi Chen

return mod from frontend for autotvm (#3401)

parent 917ad9f6
......@@ -96,7 +96,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
......
......@@ -96,7 +96,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
......
......@@ -97,7 +97,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
......
......@@ -64,7 +64,8 @@ def get_network(name, batch_size):
# an example for mxnet model
from mxnet.gluon.model_zoo.vision import get_model
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype)
net = mod[mod.entry_func]
net = relay.Function(net.params, relay.nn.softmax(net.body), None, net.type_params, net.attrs)
else:
raise ValueError("Unsupported network: " + name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment