Commit 360d26dd by youluexx Committed by Yizhi Liu

[Relay][Frontend][darknet] Solve tvm parsing darknet resnext failure bug (#3778)

* test_darkent_bug

* test_darkent

* add resnext tests
parent 5fe61fd1
...@@ -231,4 +231,4 @@ conda/pkg ...@@ -231,4 +231,4 @@ conda/pkg
# antlr files # antlr files
*.tokens *.tokens
*.interp *.interp
\ No newline at end of file
...@@ -458,11 +458,11 @@ class GraphProto(object): ...@@ -458,11 +458,11 @@ class GraphProto(object):
if layer.nweights == 0: if layer.nweights == 0:
return None return None
if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: if (layer.n * layer.c // layer.groups * layer.size * layer.size) != layer.nweights:
raise RuntimeError("layer weights size not matching with n c h w") raise RuntimeError("layer weights size not matching with n c h w")
params = {} params = {}
shape = (layer.n, layer.c, layer.size, layer.size) shape = (layer.n, layer.c // layer.groups, layer.size, layer.size)
weights = self._read_memory_buffer(shape, layer.weights) weights = self._read_memory_buffer(shape, layer.weights)
biases = self._read_memory_buffer((layer.n, ), layer.biases) biases = self._read_memory_buffer((layer.n, ), layer.biases)
......
...@@ -189,6 +189,18 @@ def test_forward_resnet50(): ...@@ -189,6 +189,18 @@ def test_forward_resnet50():
verify_darknet_frontend(net) verify_darknet_frontend(net)
LIB.free_network(net) LIB.free_network(net)
def test_forward_resnext50():
'''test resnet50 model'''
model_name = 'resnext50'
cfg_name = model_name + '.cfg'
weights_name = model_name + '.weights'
cfg_url = 'https://github.com/pjreddie/darknet/blob/master/cfg/' + cfg_name + '?raw=true'
weights_url = 'http://pjreddie.com/media/files/' + weights_name + '?raw=true'
net = _load_net(cfg_url, cfg_name, weights_url, weights_name)
verify_darknet_frontend(net)
LIB.free_network(net)
def test_forward_yolov2(): def test_forward_yolov2():
'''test yolov2 model''' '''test yolov2 model'''
model_name = 'yolov2' model_name = 'yolov2'
...@@ -441,6 +453,7 @@ def test_forward_rnn(): ...@@ -441,6 +453,7 @@ def test_forward_rnn():
if __name__ == '__main__': if __name__ == '__main__':
test_forward_resnet50() test_forward_resnet50()
test_forward_resnext50()
test_forward_alexnet() test_forward_alexnet()
test_forward_extraction() test_forward_extraction()
test_forward_yolov2() test_forward_yolov2()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment