Commit 135ed1f6 by yuxguo

fix

parent 84297604
...@@ -413,7 +413,7 @@ class SharedGroupMLP(nn.Module): ...@@ -413,7 +413,7 @@ class SharedGroupMLP(nn.Module):
# mlps indicates different experts # mlps indicates different experts
if shared: if shared:
mlps = [MLPModel(group_input_dim, group_output_dim, mlps = [MLPModel(group_input_dim, group_output_dim,
hidden_dims=hidden_dims, flatten=False) for i in range(nr_mlps)] hidden_dims=hidden_dims) for i in range(nr_mlps)]
else: else:
mlps = [GroupMLP(group_input_dim * groups, group_output_dim * groups, mlps = [GroupMLP(group_input_dim * groups, group_output_dim * groups,
hidden_dims=hidden_dims) for i in range(nr_mlps)] hidden_dims=hidden_dims) for i in range(nr_mlps)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment